Skip to content

Commit

Permalink
add unit test agent utils for reward calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
p3nGu1nZz committed Apr 8, 2024
1 parent 6e85f42 commit 4fab0d5
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 54 deletions.
40 changes: 40 additions & 0 deletions Assets/Scripts/Agents/AgentUtils.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using System;
using UnityEngine;

namespace DialogosEngine
{
/// <summary>
/// Helpers for computing agent rewards.
/// </summary>
public static class AgentUtils
{
    /// <summary>Token that every expected string must end with.</summary>
    public const string k_EndOfSequence = "<eos>";

    // Weight given to each of the two reward terms when they are blended.
    private const float k_TermWeight = 0.5f;

    /// <summary>
    /// Scores how closely <paramref name="guessedString"/> echoes
    /// <paramref name="expectedString"/>.
    /// </summary>
    /// <remarks>
    /// The reward is the equally weighted average of two terms, each rescaled to [-1, 1]:
    /// <list type="bullet">
    /// <item><description>
    /// similarity: <c>1 - distance / maxLength</c>, where <c>distance</c> is the value
    /// returned by <c>Lexer.LevenshteinDistance</c>;
    /// </description></item>
    /// <item><description>
    /// length match: <c>1 - min(2 * |lengthDifference| / maxLength, 1)</c>.
    /// </description></item>
    /// </list>
    /// Lengths are taken from <see cref="string.Length"/>, i.e. they are counted in
    /// UTF-16 code units.
    /// </remarks>
    /// <param name="expectedString">Target string; must end with <see cref="k_EndOfSequence"/>.</param>
    /// <param name="guessedString">String produced by the agent.</param>
    /// <returns>A reward in the range [-1, 1]; 1 when the two strings are identical.</returns>
    /// <exception cref="ArgumentException">
    /// Either string is null or empty, or <paramref name="expectedString"/> does not end
    /// with <see cref="k_EndOfSequence"/>.
    /// </exception>
    public static float CalculateEchoReward(string expectedString, string guessedString)
    {
        // Validate input strings
        if (string.IsNullOrEmpty(expectedString) || string.IsNullOrEmpty(guessedString))
        {
            throw new ArgumentException("Expected and guessed strings must not be null or empty.");
        }

        // Ordinal comparison: the token is a literal marker, so culture-specific
        // collation rules must not influence whether it is considered present.
        if (!expectedString.EndsWith(k_EndOfSequence, StringComparison.Ordinal))
        {
            throw new ArgumentException("Expected string must end with '<eos>'.");
        }

        // Both strings are non-empty at this point, so the divisor below is at least 1.
        float maxStringLength = Mathf.Max(expectedString.Length, guessedString.Length);

        // Similarity term from the edit distance
        int levenshteinDistance = Lexer.LevenshteinDistance(expectedString, guessedString);
        float similarityScore = ToSignedUnitRange(1f - (float)levenshteinDistance / maxStringLength);

        // Length match term; the penalty saturates once the lengths differ by at least
        // half of the longer string.
        float lengthDifference = Mathf.Abs(expectedString.Length - guessedString.Length);
        float lengthMatchScore = ToSignedUnitRange(1f - Mathf.Min(2f * lengthDifference / maxStringLength, 1f));

        // Combine similarity and length match scores
        float combinedScore = (k_TermWeight * similarityScore) + (k_TermWeight * lengthMatchScore);

        // Ensure the final score is within the range [-1, 1]
        return Mathf.Clamp(combinedScore, -1f, 1f);
    }

    // Linearly rescales a score from [0, 1] to [-1, 1].
    private static float ToSignedUnitRange(float unitScore)
    {
        return (unitScore * 2f) - 1f;
    }
}
}
2 changes: 2 additions & 0 deletions Assets/Scripts/Agents/AgentUtils.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 14 additions & 52 deletions Assets/Scripts/Agents/EchoAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,40 @@
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

namespace DialogosEngine
{
public class EchoAgent : Agent
{
string _CachedGuessedString;
const string k_EndOfSequence = "<eos>";
string _CachedString;

public override void OnEpisodeBegin()
{
ClearConsole();
Terminal.Instance.Buffer.Reset();
}

public void FixedUpdate()
{
string expectedString = GetExpectedString();
if (_CachedGuessedString != null)
if (_CachedString != null)
{
float _reward;
if (_CachedGuessedString.EndsWith(k_EndOfSequence))
float reward = AgentUtils.CalculateEchoReward(expectedString, _CachedString);

if (_CachedString.EndsWith(AgentUtils.k_EndOfSequence))
{
_reward = CalculateReward(expectedString, _CachedGuessedString);

_CachedGuessedString = _CachedGuessedString.Replace(k_EndOfSequence, "");
_CachedString = _CachedString.Replace(AgentUtils.k_EndOfSequence, "");

Terminal.Instance.Shell.Run(_CachedGuessedString);

if(Terminal.Instance.IssuedError)
{
_reward -= 0.5f; // Penalize for bad commands
}
}
else
{
_reward = -1f;
Terminal.Instance.Shell.Run(_CachedString);
}
SetReward(_reward);
_CachedGuessedString = null;

SetReward(reward);
_CachedString = null;
}
}

public override void CollectObservations(VectorSensor sensor)
{
string _buffer = GetConsoleBuffer();
string _buffer = Terminal.Instance.Buffer.GetLastLog();
float[] _vectorizedBuffer = Lexer.VectorizeUTF8(_buffer);
foreach (var obs in _vectorizedBuffer)
{
Expand All @@ -57,39 +46,12 @@ public override void CollectObservations(VectorSensor sensor)
public override void OnActionReceived(ActionBuffers actions)
{
float[] _actionArray = actions.ContinuousActions.Array;
_CachedGuessedString = Lexer.QuantizeUTF8(_actionArray);
}

private void ClearConsole()
{
Terminal.Instance.Buffer.Reset();
}

private float CalculateReward(string expectedString, string guessedString)
{
int levenshteinDistance = Lexer.LevenshteinDistance(expectedString, guessedString);
float maxStringLength = Mathf.Max(expectedString.Length, guessedString.Length);
float similarityScore = 1f - (float)levenshteinDistance / maxStringLength;
similarityScore = (similarityScore * 2f) - 1f; // Normalize to range [-1, 1]

float lengthDifference = Mathf.Abs(expectedString.Length - guessedString.Length);
float lengthMatchScore = 1f - Mathf.Min(2f * lengthDifference / maxStringLength, 1f);
lengthMatchScore = (lengthMatchScore * 2f) - 1f; // Normalize to range [-1, 1]

float combinedScore = (0.5f * similarityScore) + (0.5f * lengthMatchScore);
return Mathf.Clamp(combinedScore, -1f, 1f); // Ensure final score is within [-1, 1]
}


private string GetConsoleBuffer()
{
return Terminal.Instance.Buffer.GetLastLog();
_CachedString = Lexer.QuantizeUTF8(_actionArray);
}

private string GetExpectedString()
{
// Implementation to get the expected string for the current step
return "";
return "echo hello <eos>"; // Testing
}
}
}
166 changes: 166 additions & 0 deletions Tests/AgentUtilsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
namespace DialogosEngine.Tests
{
/// <summary>
/// Unit tests for <c>AgentUtils.CalculateEchoReward</c>.
/// </summary>
/// <remarks>
/// The expected values in the comments below follow the reward formula: the average of a
/// similarity term (<c>1 - distance / maxLength</c>) and a length term
/// (<c>1 - min(2 * |lengthDifference| / maxLength, 1)</c>), each rescaled from [0, 1] to
/// [-1, 1]. They assume <c>Lexer.LevenshteinDistance</c> returns a true edit distance.
/// </remarks>
[TestFixture]
public class AgentUtilsTests
{
    // Tolerance used when comparing computed rewards against hand-derived values.
    private const float k_Tolerance = 0.01f;

    [Test]
    public void CalculateEchoReward_ShouldReturnPerfectScore_WhenStringsMatchExactly()
    {
        // Arrange
        string expected = "hello<eos>";
        string guessed = "hello<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string: '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // Distance 0 and equal lengths give both terms their maximum of 1.
        Assert.That(reward, Is.EqualTo(1f).Within(1e-6f), "The reward should be 1.0f when the expected and guessed strings match exactly.");
        TestContext.WriteLine("Test passed: The calculated reward is 1.0f as expected.");
    }

    [Test]
    public void CalculateEchoReward_ShouldThrowArgumentException_WhenStringsAreEmpty()
    {
        // Arrange
        string expected = "";
        string guessed = "";
        TestContext.WriteLine("Testing CalculateEchoReward with empty strings.");

        // Act & Assert
        var ex = Assert.Throws<ArgumentException>(() => AgentUtils.CalculateEchoReward(expected, guessed));

        // Check the exception before its message is read for logging.
        Assert.That(ex, Is.Not.Null, "An ArgumentException should be thrown for empty strings.");
        TestContext.WriteLine($"Expected ArgumentException was thrown with message: {ex.Message}");
        TestContext.WriteLine("Test passed: ArgumentException is thrown as expected for empty strings.");
    }

    [Test]
    public void CalculateEchoReward_ShouldThrowArgumentException_WhenExpectedStringDoesNotEndWithEos()
    {
        // Arrange
        string expected = "hello";
        string guessed = "hello<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string not ending with '<eos>': '{expected}'.");

        // Act & Assert
        var ex = Assert.Throws<ArgumentException>(() => AgentUtils.CalculateEchoReward(expected, guessed));

        // Check the exception before its message is read for logging.
        Assert.That(ex, Is.Not.Null, "An ArgumentException should be thrown when the expected string does not end with '<eos>'.");
        TestContext.WriteLine($"Expected ArgumentException was thrown with message: {ex.Message}");
        TestContext.WriteLine("Test passed: ArgumentException is thrown as expected when the expected string does not end with '<eos>'.");
    }

    [Test]
    public void CalculateEchoReward_ShouldReturnPartialScore_WhenStringsAreSimilar()
    {
        // Arrange
        string expected = "hello<eos>";
        string guessed = "hallo<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string: '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // One substitution over 10 characters: similarity term 0.8, length term 1.0.
        Assert.That(reward, Is.EqualTo(0.9f).Within(k_Tolerance), "The reward should be approximately 0.9f for similar strings.");
        TestContext.WriteLine("Test passed: The calculated reward is 0.9f as expected for similar strings.");
    }

    [Test]
    public void CalculateEchoReward_ShouldReflectSimilarityAndLengthMatch_WhenEosIsMissing()
    {
        // Arrange
        string expected = "hello<eos>";
        string guessed = "hello";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string missing '<eos>': '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // Five missing characters out of 10: similarity term 0.0, length term -1.0.
        Assert.That(reward, Is.EqualTo(-0.5f).Within(k_Tolerance), "The reward should be -0.5f when the guessed string is missing the '<eos>' token, reflecting the similarity and length match.");
        TestContext.WriteLine("Test passed: The calculated reward is -0.5f as expected, reflecting the similarity and length match.");
    }

    [Test]
    public void CalculateEchoReward_ShouldReturnSpecificLowerScore_WhenGuessedStringHasAdditionalCharacters()
    {
        // Arrange
        string expected = "hello<eos>";
        string guessed = "hello there<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and guessed string with additional characters: '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // Six extra characters out of 16: similarity term 0.25, length term -0.5.
        Assert.That(reward, Is.EqualTo(-0.125f).Within(k_Tolerance), "The reward should be -0.125f when the guessed string has additional characters.");
        TestContext.WriteLine("Test passed: The calculated reward is -0.125f as expected when the guessed string has additional characters.");
    }

    [Test]
    public void CalculateEchoReward_ShouldReturnSpecificMuchLowerScore_WhenGuessedStringIsSignificantlyLonger()
    {
        // Arrange
        string expected = "hello<eos>";
        string guessed = "hello there, how are you doing today?<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and a significantly longer guessed string: '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // 32 extra characters out of 42: similarity term about -0.52, length term saturates at -1.0.
        Assert.That(reward, Is.EqualTo(-0.7619048f).Within(k_Tolerance), "The reward should be -0.7619048f when the guessed string is significantly longer than the expected string.");
        TestContext.WriteLine("Test passed: The calculated reward is -0.7619048f as expected when the guessed string is significantly longer than the expected string.");
    }

    [Test]
    public void CalculateEchoReward_ShouldReturnPositiveScore_WhenGuessedStringIsLongerButSimilar()
    {
        // Arrange
        string expected = "hello fr <eos>";
        string guessed = "hello friend<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string: '{expected}' and a longer but similar guessed string: '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // The expected reward should be positive due to the high similarity despite the additional length.
        Assert.That(reward, Is.GreaterThan(0f), "The reward should be positive when the guessed string is longer but still similar to the expected string.");
        TestContext.WriteLine("Test passed: The calculated reward is positive as expected when the guessed string is longer but still similar to the expected string.");
    }

    [Test]
    public void CalculateEchoReward_ShouldHandleEmojisInStrings()
    {
        // Arrange
        string expected = "hello👋<eos>";
        string guessed = "hello👋 friend🙂<eos>";
        TestContext.WriteLine($"Testing CalculateEchoReward with expected string containing an emoji: '{expected}' and guessed string with additional emoji: '{guessed}'.");

        // Act
        float reward = AgentUtils.CalculateEchoReward(expected, guessed);
        TestContext.WriteLine($"Calculated reward: {reward}");

        // Assert
        // string.Length counts UTF-16 code units and both emojis are surrogate pairs, so the
        // lengths are 12 and 21. An edit distance can never be smaller than the length
        // difference (9), which caps the similarity term at about 0.14, while the length
        // term is about -0.71. The blended reward is therefore at most about -0.29, so a
        // positive reward is not reachable for this pair of strings.
        Assert.That(reward, Is.LessThan(0f).And.GreaterThanOrEqualTo(-1f), "The reward should be negative but no lower than -1 when the guessed string appends this much extra text, emojis included.");
        TestContext.WriteLine("Test passed: The calculated reward is negative and within range as expected when the guessed string appends extra text and emojis.");
    }

}
}
2 changes: 0 additions & 2 deletions Tests/CommandArgTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,5 @@ public static void ParseCommand_EscapedBackslashes_HandlesCorrectly()
// Assert
Assert.That(result.String, Is.EqualTo(expected), "Escaped backslashes should be handled correctly.");
}


}
}

0 comments on commit 4fab0d5

Please sign in to comment.