From 350113202a10734513c1d2dc4dd10d3cbb9a5efe Mon Sep 17 00:00:00 2001 From: Ayrat Hudaygulov Date: Wed, 7 Aug 2024 22:24:15 +0100 Subject: [PATCH] added random sort of training data for better training results (#48) --- .env.example | 1 + src/VahterBanBot.Tests/ContainerTestBase.fs | 1 + src/VahterBanBot/ML.fs | 6 +++++- src/VahterBanBot/Program.fs | 1 + src/VahterBanBot/Types.fs | 1 + 5 files changed, 9 insertions(+), 1 deletion(-) diff --git a/.env.example b/.env.example index 25df51d..42a664b 100644 --- a/.env.example +++ b/.env.example @@ -24,6 +24,7 @@ ML_SPAM_DELETION_ENABLED=false ML_SPAM_AUTOBAN_ENABLED=true ML_SPAM_AUTOBAN_SCORE_THRESHOLD=-5.0 ML_SPAM_AUTOBAN_CHECK_LAST_MSG_COUNT=10 +ML_TRAIN_RANDOM_SORT_DATA=true ML_TRAIN_INTERVAL_DAYS=30 ML_TRAIN_CRITICAL_MSG_COUNT=5 ML_TRAINING_SET_FRACTION=0.2 diff --git a/src/VahterBanBot.Tests/ContainerTestBase.fs b/src/VahterBanBot.Tests/ContainerTestBase.fs index 0f03b57..1f9106a 100644 --- a/src/VahterBanBot.Tests/ContainerTestBase.fs +++ b/src/VahterBanBot.Tests/ContainerTestBase.fs @@ -96,6 +96,7 @@ type VahterTestContainers() = .WithEnvironment("CLEANUP_OLD_MESSAGES", "false") .WithEnvironment("ML_ENABLED", "true") .WithEnvironment("ML_SEED", "42") + .WithEnvironment("ML_TRAIN_RANDOM_SORT_DATA", "false") .WithEnvironment("ML_SPAM_DELETION_ENABLED", "true") .WithEnvironment("ML_SPAM_THRESHOLD", "1.0") .WithEnvironment("ML_STOP_WORDS_IN_CHATS", """{"-42":["2"]}""") diff --git a/src/VahterBanBot/ML.fs b/src/VahterBanBot/ML.fs index 86fc7f5..fe9b3d2 100644 --- a/src/VahterBanBot/ML.fs +++ b/src/VahterBanBot/ML.fs @@ -75,7 +75,11 @@ type MachineLearning( createdAt = x.created_at lessThanNMessagesF = if x.less_than_n_messages then 1.0f else 0.0f } ) - + |> fun x -> + if botConf.MlTrainRandomSortData then + Array.sortInPlaceBy (fun _ -> Guid.NewGuid()) x + x + let dataView = mlContext.Data.LoadFromEnumerable data let trainTestSplit = mlContext.Data.TrainTestSplit(dataView, testFraction = botConf.MlTrainingSetFraction) let trainingData = trainTestSplit.TrainSet diff --git a/src/VahterBanBot/Program.fs b/src/VahterBanBot/Program.fs index d0bb972..5e093ed 100644 --- a/src/VahterBanBot/Program.fs +++ b/src/VahterBanBot/Program.fs @@ -61,6 +61,7 @@ let botConf = MlSpamAutobanEnabled = getEnvOr "ML_SPAM_AUTOBAN_ENABLED" "false" |> bool.Parse MlSpamAutobanCheckLastMsgCount = getEnvOr "ML_SPAM_AUTOBAN_CHECK_LAST_MSG_COUNT" "10" |> int MlSpamAutobanScoreThreshold = getEnvOr "ML_SPAM_AUTOBAN_SCORE_THRESHOLD" "-5.0" |> double + MlTrainRandomSortData = getEnvOr "ML_TRAIN_RANDOM_SORT_DATA" "true" |> bool.Parse MlTrainInterval = getEnvOr "ML_TRAIN_INTERVAL_DAYS" "30" |> int |> TimeSpan.FromDays MlTrainCriticalMsgCount = getEnvOr "ML_TRAIN_CRITICAL_MSG_COUNT" "5" |> int MlTrainingSetFraction = getEnvOr "ML_TRAINING_SET_FRACTION" "0.2" |> float diff --git a/src/VahterBanBot/Types.fs b/src/VahterBanBot/Types.fs index bd5f8dc..788e70d 100644 --- a/src/VahterBanBot/Types.fs +++ b/src/VahterBanBot/Types.fs @@ -34,6 +34,7 @@ type BotConfiguration = MlSpamAutobanEnabled: bool MlSpamAutobanCheckLastMsgCount: int MlSpamAutobanScoreThreshold: double + MlTrainRandomSortData: bool MlTrainInterval: TimeSpan MlTrainCriticalMsgCount: int MlTrainingSetFraction: float