diff --git a/internal/config/interface.go b/internal/config/interface.go index 823d7db..b776c43 100644 --- a/internal/config/interface.go +++ b/internal/config/interface.go @@ -1,8 +1,6 @@ package config import ( - "time" - "gorm.io/datatypes" ) @@ -25,11 +23,10 @@ type SSHConfig struct { // Config represents the data structure of our user provided json configuration type Config struct { - ID int - Name string - SSH SSHConfig - Targets []string `json:"targets"` - Loaded time.Time + ID int + Name string + SSH SSHConfig + CIDR string } // SSHConfigModel represents the ssh config stored in the database @@ -41,31 +38,28 @@ type SSHConfigModel struct { // ConfigModel represents the config stored in the database type ConfigModel struct { - ID int `gorm:"primaryKey"` - Name string `gorm:"uniqueIndex"` - SSH SSHConfigModel `gorm:"embedded"` - Targets datatypes.JSON - Loaded time.Time `gorm:"index:,sort:desc"` + ID int `gorm:"primaryKey"` + Name string `gorm:"uniqueIndex"` + SSH SSHConfigModel `gorm:"embedded"` + CIDR string `gorm:"column:cidr"` } // Repo interface representing access to stored configs type Repo interface { Get(id int) (*Config, error) GetAll() ([]*Config, error) + GetByCIDR(cidr string) (*Config, error) Create(conf *Config) (*Config, error) Update(conf *Config) (*Config, error) Delete(id int) error - SetLastLoaded(id int) error - LastLoaded() (*Config, error) } // Service interface for manipulating configurations type Service interface { Get(id int) (*Config, error) GetAll() ([]*Config, error) + GetByCIDR(cidr string) (*Config, error) Create(conf *Config) (*Config, error) Update(conf *Config) (*Config, error) Delete(id int) error - SetLastLoaded(id int) error - LastLoaded() (*Config, error) } diff --git a/internal/config/repo.go b/internal/config/repo.go index 229b35e..04cfef4 100644 --- a/internal/config/repo.go +++ b/internal/config/repo.go @@ -3,7 +3,6 @@ package config import ( "encoding/json" "errors" - "time" "github.com/robgonnella/ops/internal/exception" "gorm.io/datatypes" @@ -38,10 +37,6 @@ func (r *SqliteRepo) Get(id int) (*Config, error) { return nil, result.Error } - if result := r.db.Save(&confModel); result.Error != nil { - return nil, result.Error - } - return modelToConfig(&confModel) } @@ -122,24 +117,11 @@ func (r *SqliteRepo) Delete(id int) error { return r.db.Delete(&ConfigModel{ID: id}).Error } -// SetLastLoaded updates a configs "loaded" field to the current timestamp -func (r *SqliteRepo) SetLastLoaded(id int) error { - confModel := ConfigModel{ID: id} - - if result := r.db.First(&confModel); result.Error != nil { - return result.Error - } - - confModel.Loaded = time.Now() - - return r.db.Save(&confModel).Error -} - // LastLoaded returns the most recently loaded config -func (r *SqliteRepo) LastLoaded() (*Config, error) { +func (r *SqliteRepo) GetByCIDR(cidr string) (*Config, error) { confModel := ConfigModel{} - if result := r.db.Order("loaded desc").First(&confModel); result.Error != nil { + if result := r.db.First(&confModel, "cidr = ?", cidr); result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, exception.ErrRecordNotFound } @@ -158,12 +140,6 @@ func modelToConfig(model *ConfigModel) (*Config, error) { return nil, err } - targets := []string{} - - if err := json.Unmarshal([]byte(model.Targets.String()), &targets); err != nil { - return nil, err - } - return &Config{ ID: model.ID, Name: model.Name, @@ -172,8 +148,7 @@ func modelToConfig(model *ConfigModel) (*Config, error) { Identity: model.SSH.Identity, Overrides: overrides, }, - Targets: targets, - Loaded: model.Loaded, + CIDR: model.CIDR, }, nil } @@ -184,12 +159,6 @@ func configToModel(conf *Config) (*ConfigModel, error) { return nil, err } - targetsBytes, err := json.Marshal(conf.Targets) - - if err != nil { - return nil, err - } - return &ConfigModel{ ID: conf.ID, Name: conf.Name, @@ -198,6 +167,6 @@ func configToModel(conf *Config) (*ConfigModel, error) { Identity: conf.SSH.Identity, Overrides: datatypes.JSON(overridesBytes), }, - Targets: datatypes.JSON(targetsBytes), + CIDR: conf.CIDR, }, nil } diff --git a/internal/config/repo_test.go b/internal/config/repo_test.go index 5aec434..17d8bbb 100644 --- a/internal/config/repo_test.go +++ b/internal/config/repo_test.go @@ -64,7 +64,7 @@ func TestConfigSqliteRepo(t *testing.T) { }, }, }, - Targets: []string{"target"}, + CIDR: "172.2.2.1/32", } newConf, err := repo.Create(conf) @@ -84,7 +84,7 @@ func TestConfigSqliteRepo(t *testing.T) { Identity: newConf.SSH.Identity, Overrides: newConf.SSH.Overrides, }, - Targets: newConf.Targets, + CIDR: newConf.CIDR, } updatedConf, err := repo.Update(toUpdate) @@ -107,19 +107,19 @@ func TestConfigSqliteRepo(t *testing.T) { conf1 := &config.Config{ Name: "test2", SSH: config.SSHConfig{ - User: "test-user1", - Identity: "test-identity1", + User: "test-user2", + Identity: "test-identity2", }, - Targets: []string{"target1"}, + CIDR: "172.2.2.2/32", } conf2 := &config.Config{ Name: "test3", SSH: config.SSHConfig{ - User: "test-user2", - Identity: "test-identity2", + User: "test-user3", + Identity: "test-identity3", }, - Targets: []string{"target2"}, + CIDR: "172.2.2.3/32", } _, err := repo.Create(conf1) @@ -144,40 +144,36 @@ func TestConfigSqliteRepo(t *testing.T) { }) - t.Run("gets last loaded", func(st *testing.T) { + t.Run("gets by cidr", func(st *testing.T) { conf1 := &config.Config{ Name: "test4", SSH: config.SSHConfig{ - User: "test-user1", - Identity: "test-identity1", + User: "test-user4", + Identity: "test-identity4", }, - Targets: []string{"target1"}, + CIDR: "172.2.2.4/32", } conf2 := &config.Config{ Name: "test5", SSH: config.SSHConfig{ - User: "test-user2", - Identity: "test-identity2", + User: "test-user5", + Identity: "test-identity5", }, - Targets: []string{"target2"}, + CIDR: "172.2.2.5/32", } - newConf1, err := repo.Create(conf1) - - assert.NoError(st, err) - - _, err = repo.Create(conf2) + _, err := repo.Create(conf1) assert.NoError(st, err) - err = repo.SetLastLoaded(newConf1.ID) + newConf2, err := repo.Create(conf2) assert.NoError(st, err) - lastLoaded, err := repo.LastLoaded() + foundConf, err := repo.GetByCIDR("172.2.2.5/32") assert.NoError(st, err) - assertEqualConf(st, newConf1, lastLoaded) + assertEqualConf(st, newConf2, foundConf) }) } diff --git a/internal/config/service.go b/internal/config/service.go index 11e5a7d..32deafa 100644 --- a/internal/config/service.go +++ b/internal/config/service.go @@ -20,6 +20,10 @@ func (s *ConfigService) GetAll() ([]*Config, error) { return s.repo.GetAll() } +func (s *ConfigService) GetByCIDR(cidr string) (*Config, error) { + return s.repo.GetByCIDR(cidr) +} + // Create creates a new config func (s *ConfigService) Create(conf *Config) (*Config, error) { return s.repo.Create(conf) @@ -34,13 +38,3 @@ func (s *ConfigService) Update(conf *Config) (*Config, error) { func (s *ConfigService) Delete(id int) error { return s.repo.Delete(id) } - -// SetLasLoaded sets the "loaded" field for a config to current timestamp -func (s *ConfigService) SetLastLoaded(id int) error { - return s.repo.SetLastLoaded(id) -} - -// LastLoaded retrieves the most recently loaded config -func (s *ConfigService) LastLoaded() (*Config, error) { - return s.repo.LastLoaded() -} diff --git a/internal/config/service_test.go b/internal/config/service_test.go index 3fb2b6e..8e690b5 100644 --- a/internal/config/service_test.go +++ b/internal/config/service_test.go @@ -26,7 +26,7 @@ func TestConfigService(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: "172.2.2.2/32", } mockRepo.EXPECT().Get(expectedConfig.ID).Return(expectedConfig, nil) @@ -44,7 +44,7 @@ func TestConfigService(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: "172.2.2.2/32", } conf2 := &config.Config{ @@ -53,7 +53,7 @@ func TestConfigService(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: "172.2.2.3/32", } expectedConfs := []*config.Config{conf1, conf2} @@ -73,7 +73,7 @@ func TestConfigService(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: "172.2.2.2/32", } mockRepo.EXPECT().Create(conf).Return(conf, nil) @@ -91,7 +91,7 @@ func TestConfigService(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: "172.2.2.2/32", } mockRepo.EXPECT().Update(conf).Return(conf, nil) @@ -112,19 +112,21 @@ func TestConfigService(t *testing.T) { assert.NoError(st, err) }) - t.Run("gets last loaded config", func(st *testing.T) { + t.Run("gets last config by cidr", func(st *testing.T) { + cidr := "172.2.2.2/32" + expectedConfig := &config.Config{ Name: "test", SSH: config.SSHConfig{ User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: cidr, } - mockRepo.EXPECT().LastLoaded().Return(expectedConfig, nil) + mockRepo.EXPECT().GetByCIDR(cidr).Return(expectedConfig, nil) - foundConf, err := service.LastLoaded() + foundConf, err := service.GetByCIDR(cidr) assert.NoError(st, err) assert.Equal(st, expectedConfig, foundConf) diff --git a/internal/core/core.go b/internal/core/core.go index 74171d9..e14f0cd 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -111,10 +111,6 @@ func (c *Core) UpdateConfig(conf config.Config) error { return err } - if err := c.configService.SetLastLoaded(updated.ID); err != nil { - return err - } - c.conf = updated return nil @@ -128,10 +124,6 @@ func (c *Core) SetConfig(id int) error { return err } - if err := c.configService.SetLastLoaded(conf.ID); err != nil { - return err - } - c.conf = conf return nil diff --git a/internal/core/core_test.go b/internal/core/core_test.go index 0ad9c61..f17366b 100644 --- a/internal/core/core_test.go +++ b/internal/core/core_test.go @@ -52,7 +52,7 @@ func TestCore(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"172.100.1.1/24"}, + CIDR: "172.100.1.1/24", } coreService := core.New( @@ -79,13 +79,11 @@ func TestCore(t *testing.T) { User: "new-user", Identity: "new-identity", }, - Targets: []string{"new-target"}, + CIDR: "192.111.1.1/28", } mockConfig.EXPECT().Update(&newConf).Return(&newConf, nil) mockConfig.EXPECT().Update(&conf).Return(&conf, nil) - mockConfig.EXPECT().SetLastLoaded(newConf.ID) - mockConfig.EXPECT().SetLastLoaded(conf.ID) err := coreService.UpdateConfig(newConf) @@ -103,13 +101,11 @@ func TestCore(t *testing.T) { User: "other-user", Identity: "other-identity", }, - Targets: []string{"other target"}, + CIDR: "172.22.2.2/32", } mockConfig.EXPECT().Get(anotherConf.ID).Return(&anotherConf, nil) mockConfig.EXPECT().Get(conf.ID).Return(&conf, nil) - mockConfig.EXPECT().SetLastLoaded(anotherConf.ID) - mockConfig.EXPECT().SetLastLoaded(conf.ID) err := coreService.SetConfig(anotherConf.ID) @@ -124,7 +120,7 @@ func TestCore(t *testing.T) { User: "new-user", Identity: "new-identity", }, - Targets: []string{"new-target"}, + CIDR: "172.22.2.2/32", } mockConfig.EXPECT().Create(&newConf).Return(&newConf, nil) @@ -151,7 +147,7 @@ func TestCore(t *testing.T) { User: "other-user", Identity: "other-identity", }, - Targets: []string{"other target"}, + CIDR: "172.22.2.3/32", } expectedConfs := []*config.Config{&conf, &anotherConf} @@ -227,8 +223,8 @@ func TestCore(t *testing.T) { mockServerService.EXPECT().StreamEvents(gomock.Any()).Return(1) mockServerService.EXPECT(). - GetAllServersInNetworkTargets(conf.Targets). - Do(func([]string) { + GetAllServersInNetwork(conf.CIDR). + Do(func(string) { wg.Done() }) mockScanner.EXPECT().Scan().DoAndReturn(func() error { diff --git a/internal/core/create.go b/internal/core/create.go index 9fa0dd1..48e59e3 100644 --- a/internal/core/create.go +++ b/internal/core/create.go @@ -48,7 +48,7 @@ func getDefaultConfig(networkInfo *util.NetworkInfo) *config.Config { Identity: identity, Overrides: []config.SSHOverride{}, }, - Targets: []string{networkInfo.Cidr}, + CIDR: networkInfo.Cidr, } } @@ -65,7 +65,7 @@ func CreateNewAppCore(networkInfo *util.NetworkInfo) (*Core, error) { configRepo := config.NewSqliteRepo(db) configService := config.NewConfigService(configRepo) - conf, err := configService.LastLoaded() + conf, err := configService.GetByCIDR(networkInfo.Cidr) if err != nil { if errors.Is(err, exception.ErrRecordNotFound) { @@ -87,7 +87,6 @@ func CreateNewAppCore(networkInfo *util.NetworkInfo) (*Core, error) { netScanner, err := discovery.NewARPScanner( networkInfo, - conf.Targets, resultChan, ) diff --git a/internal/core/monitor.go b/internal/core/monitor.go index b6fdc7d..2818aa5 100644 --- a/internal/core/monitor.go +++ b/internal/core/monitor.go @@ -76,8 +76,8 @@ func (c *Core) pollForDatabaseUpdates() error { return fmt.Errorf("too many consecutive errors encountered") } - response, err := c.serverService.GetAllServersInNetworkTargets( - c.conf.Targets, + response, err := c.serverService.GetAllServersInNetwork( + c.conf.CIDR, ) if err != nil { diff --git a/internal/discovery/arpscan.go b/internal/discovery/arpscan.go index 5d9b912..527cbc8 100644 --- a/internal/discovery/arpscan.go +++ b/internal/discovery/arpscan.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "net" - "regexp" "sync" "github.com/google/gopacket" @@ -17,8 +16,6 @@ import ( "github.com/rs/zerolog/log" ) -var cidrSuffix = regexp.MustCompile(`\/\d{2}$`) - type ARPScanner struct { ctx context.Context cancel context.CancelFunc @@ -34,25 +31,18 @@ type ARPScanner struct { func NewARPScanner( networkInfo *util.NetworkInfo, - targets []string, resultChan chan *DiscoveryResult, ) (*ARPScanner, error) { ipList := []string{} - for _, t := range targets { - if cidrSuffix.MatchString(t) { - ips, err := mapcidr.IPAddresses(t) - - if err != nil { - return nil, err - } + ips, err := mapcidr.IPAddresses(networkInfo.Cidr) - ipList = append(ipList, ips...) - } else { - ipList = append(ipList, t) - } + if err != nil { + return nil, err } + ipList = append(ipList, ips...) + // Open up a pcap handle for packet reads/writes. handle, err := pcap.OpenLive( networkInfo.Interface.Name, diff --git a/internal/mock/config/mock_config.go b/internal/mock/config/mock_config.go index c625f71..13233ae 100644 --- a/internal/mock/config/mock_config.go +++ b/internal/mock/config/mock_config.go @@ -93,33 +93,19 @@ func (mr *MockRepoMockRecorder) GetAll() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAll", reflect.TypeOf((*MockRepo)(nil).GetAll)) } -// LastLoaded mocks base method. -func (m *MockRepo) LastLoaded() (*config.Config, error) { +// GetByCIDR mocks base method. +func (m *MockRepo) GetByCIDR(arg0 string) (*config.Config, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LastLoaded") + ret := m.ctrl.Call(m, "GetByCIDR", arg0) ret0, _ := ret[0].(*config.Config) ret1, _ := ret[1].(error) return ret0, ret1 } -// LastLoaded indicates an expected call of LastLoaded. -func (mr *MockRepoMockRecorder) LastLoaded() *gomock.Call { +// GetByCIDR indicates an expected call of GetByCIDR. +func (mr *MockRepoMockRecorder) GetByCIDR(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastLoaded", reflect.TypeOf((*MockRepo)(nil).LastLoaded)) -} - -// SetLastLoaded mocks base method. -func (m *MockRepo) SetLastLoaded(arg0 int) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetLastLoaded", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetLastLoaded indicates an expected call of SetLastLoaded. -func (mr *MockRepoMockRecorder) SetLastLoaded(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLastLoaded", reflect.TypeOf((*MockRepo)(nil).SetLastLoaded), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByCIDR", reflect.TypeOf((*MockRepo)(nil).GetByCIDR), arg0) } // Update mocks base method. @@ -219,33 +205,19 @@ func (mr *MockServiceMockRecorder) GetAll() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAll", reflect.TypeOf((*MockService)(nil).GetAll)) } -// LastLoaded mocks base method. -func (m *MockService) LastLoaded() (*config.Config, error) { +// GetByCIDR mocks base method. +func (m *MockService) GetByCIDR(arg0 string) (*config.Config, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "LastLoaded") + ret := m.ctrl.Call(m, "GetByCIDR", arg0) ret0, _ := ret[0].(*config.Config) ret1, _ := ret[1].(error) return ret0, ret1 } -// LastLoaded indicates an expected call of LastLoaded. -func (mr *MockServiceMockRecorder) LastLoaded() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastLoaded", reflect.TypeOf((*MockService)(nil).LastLoaded)) -} - -// SetLastLoaded mocks base method. -func (m *MockService) SetLastLoaded(arg0 int) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetLastLoaded", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetLastLoaded indicates an expected call of SetLastLoaded. -func (mr *MockServiceMockRecorder) SetLastLoaded(arg0 interface{}) *gomock.Call { +// GetByCIDR indicates an expected call of GetByCIDR. +func (mr *MockServiceMockRecorder) GetByCIDR(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetLastLoaded", reflect.TypeOf((*MockService)(nil).SetLastLoaded), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByCIDR", reflect.TypeOf((*MockService)(nil).GetByCIDR), arg0) } // Update mocks base method. diff --git a/internal/mock/server/mock_server.go b/internal/mock/server/mock_server.go index 1174611..53f28d0 100644 --- a/internal/mock/server/mock_server.go +++ b/internal/mock/server/mock_server.go @@ -176,19 +176,19 @@ func (mr *MockServiceMockRecorder) GetAllServers() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServers", reflect.TypeOf((*MockService)(nil).GetAllServers)) } -// GetAllServersInNetworkTargets mocks base method. -func (m *MockService) GetAllServersInNetworkTargets(arg0 []string) ([]*server.Server, error) { +// GetAllServersInNetwork mocks base method. +func (m *MockService) GetAllServersInNetwork(arg0 string) ([]*server.Server, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAllServersInNetworkTargets", arg0) + ret := m.ctrl.Call(m, "GetAllServersInNetwork", arg0) ret0, _ := ret[0].([]*server.Server) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetAllServersInNetworkTargets indicates an expected call of GetAllServersInNetworkTargets. -func (mr *MockServiceMockRecorder) GetAllServersInNetworkTargets(arg0 interface{}) *gomock.Call { +// GetAllServersInNetwork indicates an expected call of GetAllServersInNetwork. +func (mr *MockServiceMockRecorder) GetAllServersInNetwork(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServersInNetworkTargets", reflect.TypeOf((*MockService)(nil).GetAllServersInNetworkTargets), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllServersInNetwork", reflect.TypeOf((*MockService)(nil).GetAllServersInNetwork), arg0) } // GetServer mocks base method. diff --git a/internal/server/interface.go b/internal/server/interface.go index 187b2eb..ce008f8 100644 --- a/internal/server/interface.go +++ b/internal/server/interface.go @@ -46,7 +46,7 @@ type Repo interface { // Service interface for server related logic type Service interface { GetAllServers() ([]*Server, error) - GetAllServersInNetworkTargets(targets []string) ([]*Server, error) + GetAllServersInNetwork(cidr string) ([]*Server, error) AddOrUpdateServer(req *Server) error MarkServerOffline(ip string) error StreamEvents(send chan *event.Event) int diff --git a/internal/server/service.go b/internal/server/service.go index fff1b29..1cb24b3 100644 --- a/internal/server/service.go +++ b/internal/server/service.go @@ -2,6 +2,7 @@ package server import ( "errors" + "fmt" "net" "sync" @@ -63,9 +64,9 @@ func (s *ServerService) GetAllServers() ([]*Server, error) { return s.repo.GetAllServers() } -// GetAllServersInNetworkTargets returns all servers in database that have ips +// GetAllServersInNetwork returns all servers in database that have ips // within the provided list of network targets -func (s *ServerService) GetAllServersInNetworkTargets(targets []string) ([]*Server, error) { +func (s *ServerService) GetAllServersInNetwork(cidr string) ([]*Server, error) { allServers, err := s.GetAllServers() result := []*Server{} @@ -74,33 +75,25 @@ func (s *ServerService) GetAllServersInNetworkTargets(targets []string) ([]*Serv return nil, err } + _, ipnet, err := net.ParseCIDR(cidr) + + if err != nil { + return nil, err + } + + if ipnet == nil { + return nil, fmt.Errorf("failed to parse cidr: %s", cidr) + } + for _, server := range allServers { - for _, target := range targets { - _, ipnet, err := net.ParseCIDR(target) - - if err != nil { - // non CIDR target just check if target matches IP - if server.IP == target { - result = append(result, server) - break - } - - // target is not a cidr and does not match server ip - // just continue looping targets - continue - } - - svrNetIP := net.ParseIP(server.IP) - - if ipnet != nil && ipnet.Contains(svrNetIP) { - // server IP is within target CIDR block - result = append(result, server) - break - } + svrNetIP := net.ParseIP(server.IP) + + if ipnet.Contains(svrNetIP) { + // server IP is within target CIDR block + result = append(result, server) } - } - // s.log.Error().Interface("result", result).Msg("returning all servers in target") + } return result, nil } diff --git a/internal/server/service_test.go b/internal/server/service_test.go index 99c741d..0edbde4 100644 --- a/internal/server/service_test.go +++ b/internal/server/service_test.go @@ -25,7 +25,7 @@ func TestServerService(t *testing.T) { User: "user", Identity: "identity", }, - Targets: []string{"target"}, + CIDR: "172.22.2.2/32", } service := server.NewService(conf, mockRepo) @@ -50,9 +50,7 @@ func TestServerService(t *testing.T) { assert.Equal(st, expectedServers, foundServers) }) - t.Run("gets all servers in network targets", func(st *testing.T) { - targets := []string{"192.168.1.10", "172.16.1.1/24"} - + t.Run("gets all servers in network", func(st *testing.T) { testServer1 := *testServer testServer2 := *testServer testServer3 := *testServer @@ -69,12 +67,12 @@ func TestServerService(t *testing.T) { expectedServers := []*server.Server{ &testServer1, - &testServer2, + &testServer3, } mockRepo.EXPECT().GetAllServers().Return(testServers, nil) - foundServers, err := service.GetAllServersInNetworkTargets(targets) + foundServers, err := service.GetAllServersInNetwork("192.168.1.1/24") assert.NoError(st, err) assert.Equal(st, 2, len(foundServers)) diff --git a/internal/ui/component/configure.go b/internal/ui/component/configure.go index c87a9cc..5af2df8 100644 --- a/internal/ui/component/configure.go +++ b/internal/ui/component/configure.go @@ -1,8 +1,6 @@ package component import ( - "strings" - "github.com/gdamore/tcell/v2" "github.com/rivo/tview" "github.com/robgonnella/ops/internal/config" @@ -39,7 +37,8 @@ func addBlankFormItems( sshIdentityInput.SetLabel("SSH Identity: ") cidrInput := tview.NewInputField() - cidrInput.SetLabel("Comma Separated CIDRs or IPs: ") + cidrInput.SetLabel("Network CIDR (auto-configured): ") + cidrInput.SetDisabled(true) form.AddFormItem(configName) form.AddFormItem(cidrInput) @@ -61,15 +60,17 @@ func addBlankFormItems( } // every time the add ssh override button is clicked we add three new inputs -func createOverrideInputs() (*tview.InputField, *tview.InputField, *tview.InputField) { +func createOverrideInputs(conf config.Config) (*tview.InputField, *tview.InputField, *tview.InputField) { overrideTarget := tview.NewInputField() - overrideTarget.SetLabel("Override Target: ") + overrideTarget.SetLabel("Override Target IP: ") overrideSSHUser := tview.NewInputField() overrideSSHUser.SetLabel("Override SSH User: ") + overrideSSHUser.SetText(conf.SSH.User) overrideSSHIdentity := tview.NewInputField() overrideSSHIdentity.SetLabel("Override SSH Identity: ") + overrideSSHIdentity.SetText(conf.SSH.Identity) return overrideTarget, overrideSSHUser, overrideSSHIdentity } @@ -117,7 +118,7 @@ func (f *ConfigureForm) render() { f.configName, f.sshUserInput, f.sshIdentityInput, f.cidrInput = addBlankFormItems(f.root, f.conf.Name) - networkTargets := strings.Join(f.conf.Targets, ",") + networkTargets := f.conf.CIDR f.configName.SetText(f.conf.Name) f.sshUserInput.SetText(f.conf.SSH.User) @@ -125,7 +126,7 @@ func (f *ConfigureForm) render() { f.cidrInput.SetText(networkTargets) for _, o := range f.conf.SSH.Overrides { - target, user, identity := createOverrideInputs() + target, user, identity := createOverrideInputs(f.conf) f.overrides = append(f.overrides, map[string]*tview.InputField{ "target": target, @@ -156,7 +157,7 @@ func (f *ConfigureForm) addFormButtons() { }) f.root.AddButton("Add SSH Override", func() { - target, user, identity := createOverrideInputs() + target, user, identity := createOverrideInputs(f.conf) f.overrides = append(f.overrides, map[string]*tview.InputField{ "target": target, @@ -195,7 +196,6 @@ func (f *ConfigureForm) addFormButtons() { return } - targets := strings.Split(cidr, ",") confOverrides := []config.SSHOverride{} for _, o := range f.overrides { @@ -215,7 +215,7 @@ func (f *ConfigureForm) addFormButtons() { Identity: sshIdentity, Overrides: confOverrides, }, - Targets: targets, + CIDR: cidr, } if f.creatingNewConfig { diff --git a/internal/ui/component/context.go b/internal/ui/component/context.go index 6c73faf..8e53029 100644 --- a/internal/ui/component/context.go +++ b/internal/ui/component/context.go @@ -2,7 +2,6 @@ package component import ( "strconv" - "strings" "github.com/gdamore/tcell/v2" "github.com/rivo/tview" @@ -26,7 +25,7 @@ func NewConfigContext( ) *ConfigContext { log := logger.New() - colHeaders := []string{"ID", "Name", "Target", "SSH-User", "SSH-Identity", "Overrides"} + colHeaders := []string{"ID", "Name", "CIDR", "SSH-User", "SSH-Identity", "Overrides"} table := createTable("Context", colHeaders) table.SetInputCapture(func(evt *tcell.EventKey) *tcell.EventKey { @@ -77,7 +76,7 @@ func (c *ConfigContext) UpdateConfigs(current int, confs []*config.Config) { id := conf.ID idStr := strconv.Itoa(id) name := conf.Name - target := strings.Join(conf.Targets, ",") + cidr := conf.CIDR sshUser := conf.SSH.User sshIdentity := conf.SSH.Identity overrides := "N" @@ -86,7 +85,7 @@ func (c *ConfigContext) UpdateConfigs(current int, confs []*config.Config) { overrides = "Y" } - row := []string{idStr, name, target, sshUser, sshIdentity, overrides} + row := []string{idStr, name, cidr, sshUser, sshIdentity, overrides} for col, text := range row { if id == current && col == 1 { diff --git a/internal/ui/component/header.go b/internal/ui/component/header.go index d6d30c1..2acbb5e 100644 --- a/internal/ui/component/header.go +++ b/internal/ui/component/header.go @@ -2,7 +2,6 @@ package component import ( "fmt" - "strings" "github.com/rivo/tview" "github.com/robgonnella/ops/internal/ui/style" @@ -23,14 +22,14 @@ type Header struct { legendCol1 *tview.Flex legendCol2 *tview.Flex switchViewInput *SwitchViewInput - targets []string + cidr string extraLegendMap map[string]tview.Primitive } // NewHeader returns a new instance of Header func NewHeader( userIP string, - targets []string, + cidr string, onViewSwitch func(text string), ) *Header { h := &Header{} @@ -70,9 +69,9 @@ func NewHeader( currentTarget := tview.NewTextView(). SetText( fmt.Sprintf( - "IP: %s, Network Targets: %s", + "IP: %s, Network Target: %s", userIP, - strings.Join(targets, ","), + cidr, ), ) @@ -82,7 +81,7 @@ func NewHeader( h.root.AddItem(currentTarget, 1, 1, false) h.root.AddItem(h.switchViewInput.Primitive(), 3, 1, false) - h.targets = targets + h.cidr = cidr h.extraLegendMap = map[string]tview.Primitive{} diff --git a/internal/ui/view.go b/internal/ui/view.go index 2858ec4..a98c431 100644 --- a/internal/ui/view.go +++ b/internal/ui/view.go @@ -92,7 +92,7 @@ func (v *view) initialize( v.header = component.NewHeader( netInfo.UserIP.String(), - v.appCore.Conf().Targets, + v.appCore.Conf().CIDR, v.onActionSubmit, ) v.serverTable = component.NewServerTable(