From 560bfbd95f9c208dbbed8588a260a2abcdbc3992 Mon Sep 17 00:00:00 2001 From: Jarv <6jarv91@gmail.com> Date: Tue, 27 Feb 2024 22:23:08 +0100 Subject: [PATCH] Added serverDataStore Added setCustomTimeoutRole --- cmd/jarvbot/bunker.go | 16 ++++++++++++-- cmd/jarvbot/commands.go | 32 ++++++++++++++++++++++------ cmd/jarvbot/config.go | 5 +++-- cmd/jarvbot/db.go | 36 +++++++++++++++++++++++++++++++ cmd/jarvbot/db_test.go | 47 +++++++++++++++++++++++++++++++++++++++++ cmd/jarvbot/main.go | 5 +---- 6 files changed, 127 insertions(+), 14 deletions(-) create mode 100644 cmd/jarvbot/db_test.go diff --git a/cmd/jarvbot/bunker.go b/cmd/jarvbot/bunker.go index c383505..52ba787 100644 --- a/cmd/jarvbot/bunker.go +++ b/cmd/jarvbot/bunker.go @@ -17,7 +17,7 @@ func answerLiquid(ds *discordgo.Session, mc *discordgo.MessageCreate, ctx contex } func answerDon(ds *discordgo.Session, mc *discordgo.MessageCreate, ctx context.Context) bool { - timeoutRole, err := guildRoleByName(ds, mc.GuildID, timeoutRoleName) + timeoutRole, err := getTimeoutRole(ds, mc.GuildID) notifyIfErr("answerDon, couldn't get timeoutRole", err, ds) if err != nil { return false @@ -45,7 +45,7 @@ func answerShoot(ds *discordgo.Session, mc *discordgo.MessageCreate, ctx context return false } - timeoutRole, err := guildRoleByName(ds, mc.GuildID, timeoutRoleName) + timeoutRole, err := getTimeoutRole(ds, mc.GuildID) notifyIfErr("answerShoot: get timeout role", err, ds) if err != nil { return false @@ -68,6 +68,18 @@ func answerShoot(ds *discordgo.Session, mc *discordgo.MessageCreate, ctx context return err == nil } +func getTimeoutRole(ds *discordgo.Session, guildID string) (*discordgo.Role, error) { + customRoleName, err := serverDS.getServerProperty(guildID, customTimeoutRoleNameKey) + if err != nil { + customRoleName = defaultTimeoutRoleName + } + return guildRoleByName(ds, guildID, customRoleName) +} + +func setCustomTimeoutRole(ds *discordgo.Session, guildID string, roleName string) error { + return serverDS.setServerProperty(guildID, customTimeoutRoleNameKey, roleName) +} + // Internal functions func shoot(ds *discordgo.Session, channelID string, guildID string, shooter *discordgo.Member, target *discordgo.Member, timeoutRoleID string) error { diff --git a/cmd/jarvbot/commands.go b/cmd/jarvbot/commands.go index b147207..29f467e 100644 --- a/cmd/jarvbot/commands.go +++ b/cmd/jarvbot/commands.go @@ -43,6 +43,7 @@ func onMessageCreated(ctx context.Context) func(ds *discordgo.Session, mc *disco // the command key must be lowercased var commands = map[string]command{ // public + "!version": simpleTextResponse("v3.0.3"), "!source": simpleTextResponse("Source code: https://github.com/j4rv/discord-bot"), "!genshindailycheckin": answerGenshinDailyCheckIn, "!genshindailycheckinstop": answerGenshinDailyCheckInStop, @@ -63,12 +64,13 @@ var commands = map[string]command{ "!shoot": notSpammable(answerShoot), "!pp": notSpammable(answerPP), // only available for discord mods - "!roleids": modOnly(answerRoleIDs), - "!react4roles": modOnly(answerMakeReact4RolesMsg), - "!addcommand": modOnly(answerAddCommand), - "!removecommand": modOnly(answerRemoveCommand), - "!allowspamming": modOnly(answerAllowSpamming), - "!preventspamming": modOnly(answerPreventSpamming), + "!roleids": modOnly(answerRoleIDs), + "!react4roles": modOnly(answerMakeReact4RolesMsg), + "!addcommand": modOnly(answerAddCommand), + "!removecommand": modOnly(answerRemoveCommand), + "!allowspamming": modOnly(answerAllowSpamming), + "!preventspamming": modOnly(answerPreventSpamming), + "!setcustomtimeoutrole": modOnly(answerSetCustomTimeoutRole), // only available for the bot owner "!addglobalcommand": adminOnly(answerAddGlobalCommand), "!removeglobalcommand": adminOnly(answerRemoveGlobalCommand), @@ -180,6 +182,24 @@ func answerPreventSpamming(ds *discordgo.Session, mc *discordgo.MessageCreate, c return err == nil } +func answerSetCustomTimeoutRole(ds *discordgo.Session, mc *discordgo.MessageCreate, ctx context.Context) bool { + guildID := mc.GuildID + + timeoutRoleName := strings.TrimSpace(commandPrefixRegex.ReplaceAllString(mc.Content, "")) + _, err := guildRoleByName(ds, guildID, timeoutRoleName) + if err != nil { + ds.ChannelMessageSend(mc.ChannelID, fmt.Sprintf("Could not find role '%s'", timeoutRoleName)) + return false + } + + err = setCustomTimeoutRole(ds, guildID, timeoutRoleName) + notifyIfErr("setCustomTimeoutRole", err, ds) + if err == nil { + ds.ChannelMessageSend(mc.ChannelID, fmt.Sprintf("Custom timeout role set to '%s'", timeoutRoleName)) + } + return err == nil +} + // ---------- Simple command stuff ---------- func answerAddCommand(ds *discordgo.Session, mc *discordgo.MessageCreate, ctx context.Context) bool { diff --git a/cmd/jarvbot/config.go b/cmd/jarvbot/config.go index 213e420..23abb0c 100644 --- a/cmd/jarvbot/config.go +++ b/cmd/jarvbot/config.go @@ -4,7 +4,7 @@ import "time" // TODO make this a config file -const dbFilename = "db.sqlite" +var dbFilename = "db.sqlite" var strongboxMinAmount = 1.0 var strongboxMaxAmount = 64.0 @@ -13,7 +13,8 @@ var warnMessageMaxLength = 320 const avatarTargetSize = "1024" -const timeoutRoleName = "Shadow Realm" +const customTimeoutRoleNameKey = "custom_timeout_role_name" +const defaultTimeoutRoleName = "Shadow Realm" const shootMisfireChance = 0.2 const nuclearCatastropheChance = 0.006 const timeoutDurationWhenShot = 4 * time.Minute diff --git a/cmd/jarvbot/db.go b/cmd/jarvbot/db.go index 7835530..940f19c 100644 --- a/cmd/jarvbot/db.go +++ b/cmd/jarvbot/db.go @@ -14,6 +14,7 @@ import ( var moddingDS moddingDataStore var genshinDS genshinDataStore var commandDS commandDataStore +var serverDS serverDataStore var errZeroRowsAffected = errors.New("zero rows were affected") @@ -25,6 +26,7 @@ func createTables(db *sqlx.DB) { createTableSpammableChannel(db) createTableUserWarning(db) createTableReact4RoleMessage(db) + createTableServerProperties(db) } func createTableDailyCheckInReminder(db *sqlx.DB) { @@ -95,6 +97,17 @@ func createTableReact4RoleMessage(db *sqlx.DB) { createIndex("React4RoleMessage", "MessageID", db) } +func createTableServerProperties(db *sqlx.DB) { + createTable("ServerProperties", []string{ + "ServerID VARCHAR(20) UNIQUE NOT NULL", + "PropertyName VARCHAR(32) NOT NULL", + "PropertyValue TEXT NOT NULL", + "CreatedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP", + "UNIQUE(ServerID, PropertyName)", + }, db) + createIndex("ServerProperties", "ServerID", db) +} + // commands type commandDataStore struct { @@ -266,6 +279,29 @@ func (s moddingDataStore) deleteReact4Roles(channelID, messageID string) error { return err } +// server + +type serverDataStore struct { + db *sqlx.DB +} + +func (s *serverDataStore) setServerProperty(serverID, propertyName, propertyValue string) error { + _, err := s.db.Exec(` + INSERT INTO ServerProperties (ServerID, PropertyName, PropertyValue) + VALUES (?, ?, ?) + ON CONFLICT(ServerID, PropertyName) + DO UPDATE SET PropertyValue = excluded.PropertyValue`, + serverID, propertyName, propertyValue) + return err +} + +func (s *serverDataStore) getServerProperty(serverID, propertyName string) (string, error) { + var propertyValue string + err := s.db.Get(&propertyValue, `SELECT PropertyValue FROM ServerProperties WHERE ServerID = ? AND PropertyName = ?`, + serverID, propertyName) + return propertyValue, err +} + // methods for repetitive stuff func createTable(table string, columns []string, db *sqlx.DB) { diff --git a/cmd/jarvbot/db_test.go b/cmd/jarvbot/db_test.go new file mode 100644 index 0000000..e264115 --- /dev/null +++ b/cmd/jarvbot/db_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "os" + "testing" +) + +func initTestDB() { + dbFilename = "test.sqlite" + if _, err := os.Stat(dbFilename); err == nil { + err = os.Remove(dbFilename) + if err != nil { + panic("Failed to delete the database file: " + err.Error()) + } + } + initDB() +} + +func TestServerProperties(t *testing.T) { + initTestDB() + + _, err := serverDS.getServerProperty("0000", "key") + if err == nil { + t.Error("Expected error, got nil") + } + + serverDS.setServerProperty("0000", "key", "value") + val, err := serverDS.getServerProperty("0000", "key") + if err != nil { + t.Error(err) + } + if val != "value" { + t.Errorf("Expected 'value', got '%s'", val) + } + + err = serverDS.setServerProperty("0000", "key", "value2") + if err != nil { + t.Error(err) + } + val, err = serverDS.getServerProperty("0000", "key") + if err != nil { + t.Error(err) + } + if val != "value2" { + t.Errorf("Expected 'value2', got '%s'", val) + } +} diff --git a/cmd/jarvbot/main.go b/cmd/jarvbot/main.go index 9378167..3ec7751 100644 --- a/cmd/jarvbot/main.go +++ b/cmd/jarvbot/main.go @@ -4,11 +4,9 @@ import ( "context" "flag" "log" - "math/rand" "os" "os/signal" "syscall" - "time" "github.com/bwmarrin/discordgo" "github.com/jmoiron/sqlx" @@ -23,8 +21,6 @@ var noSlashCommands bool const discordMaxMessageLength = 2000 func main() { - rand.Seed(time.Now().UTC().UnixNano()) - initFlags() initDB() ds := initDiscordSession() @@ -66,6 +62,7 @@ func initDB() { genshinDS = genshinDataStore{db} commandDS = commandDataStore{db} moddingDS = moddingDataStore{db} + serverDS = serverDataStore{db} } func initDiscordSession() *discordgo.Session {