diff --git a/internal/config/interface.go b/internal/config/interface.go index a9d69bb..ab5100b 100644 --- a/internal/config/interface.go +++ b/internal/config/interface.go @@ -12,12 +12,14 @@ type SSHOverride struct { Target string `json:"target"` User string `json:"user"` Identity string `json:"identity"` + Port string `json:"port"` } // SSHConfig represents the config needed to ssh to servers type SSHConfig struct { User string `json:"user"` Identity string `json:"identity"` + Port string `json:"port"` Overrides []SSHOverride `json:"overrides"` } diff --git a/internal/config/repo-sqlite.go b/internal/config/repo-sqlite.go deleted file mode 100644 index bbe971f..0000000 --- a/internal/config/repo-sqlite.go +++ /dev/null @@ -1,175 +0,0 @@ -package config - -import ( - "encoding/json" - "errors" - - "github.com/google/uuid" - "github.com/robgonnella/ops/internal/exception" - "gorm.io/datatypes" - "gorm.io/gorm" -) - -// SqliteRepo is our repo implementation for sqlite -type SqliteRepo struct { - db *gorm.DB -} - -// NewSqliteRepo returns a new ops sqlite db -func NewSqliteRepo(db *gorm.DB) *SqliteRepo { - return &SqliteRepo{ - db: db, - } -} - -// Get returns a config from the db -func (r *SqliteRepo) Get(id string) (*Config, error) { - if id == "" { - return nil, errors.New("config id cannot be empty") - } - - confModel := ConfigModel{ID: id} - - if result := r.db.First(&confModel); result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, exception.ErrRecordNotFound - } - - return nil, result.Error - } - - return modelToConfig(&confModel) -} - -// GetAll returns all configs in db -func (r *SqliteRepo) GetAll() ([]*Config, error) { - confModels := []ConfigModel{} - - if result := r.db.Find(&confModels); result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, exception.ErrRecordNotFound - } - - return nil, result.Error - } - - confs := []*Config{} - - for _, m := range confModels { - c, err := modelToConfig(&m) - - if err != nil { - return nil, err - } - - confs = append(confs, c) - } - - return confs, nil -} - -// Create creates a new config in db -func (r *SqliteRepo) Create(conf *Config) (*Config, error) { - if conf.Name == "" { - return nil, errors.New("config name cannot be empty") - } - - conf.ID = uuid.New().String() - - confModel, err := configToModel(conf) - - if err != nil { - return nil, err - } - - // create or update - result := r.db.Create(confModel) - - if result.Error != nil { - return nil, result.Error - } - - return modelToConfig(confModel) -} - -// Update updates a config in db -func (r *SqliteRepo) Update(conf *Config) (*Config, error) { - if conf.ID == "" { - return nil, errors.New("config ID cannot be empty") - } - - confModel, err := configToModel(conf) - - if err != nil { - return nil, err - } - - if result := r.db.Save(confModel); result.Error != nil { - return nil, result.Error - } - - return modelToConfig(confModel) -} - -// Delete deletes a config from db -func (r *SqliteRepo) Delete(id string) error { - if id == "" { - return errors.New("config id cannot be empty") - } - - return r.db.Delete(&ConfigModel{ID: id}).Error -} - -// LastLoaded returns the most recently loaded config -func (r *SqliteRepo) GetByCIDR(cidr string) (*Config, error) { - confModel := ConfigModel{} - - if result := r.db.First(&confModel, "cidr = ?", cidr); result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, exception.ErrRecordNotFound - } - - return nil, result.Error - } - - return modelToConfig(&confModel) -} - -// helpers -func modelToConfig(model *ConfigModel) (*Config, error) { - overrides := []SSHOverride{} - - if err := json.Unmarshal([]byte(model.SSH.Overrides.String()), &overrides); err != nil { - return nil, err - } - - return &Config{ - ID: model.ID, - Name: model.Name, - SSH: SSHConfig{ - User: model.SSH.User, - Identity: model.SSH.Identity, - Overrides: overrides, - }, - CIDR: model.CIDR, - }, nil -} - -func configToModel(conf *Config) (*ConfigModel, error) { - overridesBytes, err := json.Marshal(conf.SSH.Overrides) - - if err != nil { - return nil, err - } - - return &ConfigModel{ - ID: conf.ID, - Name: conf.Name, - SSH: SSHConfigModel{ - User: conf.SSH.User, - Identity: conf.SSH.Identity, - Overrides: datatypes.JSON(overridesBytes), - }, - CIDR: conf.CIDR, - }, nil -} diff --git a/internal/config/repo-yaml.go b/internal/config/repo-yaml.go index e8ed825..5234f6e 100644 --- a/internal/config/repo-yaml.go +++ b/internal/config/repo-yaml.go @@ -11,7 +11,6 @@ import ( "github.com/google/uuid" "github.com/robgonnella/ops/internal/exception" - "github.com/spf13/viper" ) // JSONRepo is our repo implementation for json @@ -22,9 +21,7 @@ type JSONRepo struct { } // NewJSONRepo returns a new ops repo for flat yaml file -func NewJSONRepo() *JSONRepo { - configPath := viper.Get("config-path").(string) - +func NewJSONRepo(configPath string) *JSONRepo { repo := &JSONRepo{ configPath: configPath, configs: []*Config{}, @@ -226,6 +223,7 @@ func copyConfig(c *Config) *Config { SSH: SSHConfig{ User: c.SSH.User, Identity: c.SSH.Identity, + Port: c.SSH.Port, Overrides: c.SSH.Overrides, }, CIDR: c.CIDR, diff --git a/internal/config/repo_test.go b/internal/config/repo_test.go index 4af8ff2..eac9090 100644 --- a/internal/config/repo_test.go +++ b/internal/config/repo_test.go @@ -6,7 +6,6 @@ import ( "github.com/robgonnella/ops/internal/config" "github.com/robgonnella/ops/internal/exception" - "github.com/robgonnella/ops/internal/test_util" "github.com/stretchr/testify/assert" ) @@ -22,26 +21,14 @@ func assertEqualConf(t *testing.T, expected, actual *config.Config) { } } -func TestConfigSqliteRepo(t *testing.T) { - testDBFile := "config.db" +func TestConfigYamlRepo(t *testing.T) { + testConfigFile := "config.json" defer func() { - os.RemoveAll(testDBFile) + os.RemoveAll(testConfigFile) }() - db, err := test_util.GetDBConnection(testDBFile) - - if err != nil { - t.Logf("failed to create test db: %s", err.Error()) - t.FailNow() - } - - if err := test_util.Migrate(db, config.ConfigModel{}); err != nil { - t.Logf("failed to migrate test db: %s", err.Error()) - t.FailNow() - } - - repo := config.NewSqliteRepo(db) + repo := config.NewJSONRepo(testConfigFile) t.Run("returns record not found error", func(st *testing.T) { _, err := repo.Get("10") diff --git a/internal/core/create.go b/internal/core/create.go index d83658d..44e483e 100644 --- a/internal/core/create.go +++ b/internal/core/create.go @@ -27,6 +27,7 @@ func getDefaultConfig(networkInfo *network.NetworkInfo) *config.Config { SSH: config.SSHConfig{ User: user, Identity: identity, + Port: "22", Overrides: []config.SSHOverride{}, }, CIDR: networkInfo.Cidr, @@ -35,7 +36,8 @@ func getDefaultConfig(networkInfo *network.NetworkInfo) *config.Config { // CreateNewAppCore creates and returns a new instance of *core.Core func CreateNewAppCore(networkInfo *network.NetworkInfo) (*Core, error) { - configRepo := config.NewJSONRepo() + configPath := viper.Get("config-path").(string) + configRepo := config.NewJSONRepo(configPath) configService := config.NewConfigService(configRepo) conf, err := configService.GetByCIDR(networkInfo.Cidr) diff --git a/internal/discovery/service.go b/internal/discovery/service.go index b8b1540..9d274b1 100644 --- a/internal/discovery/service.go +++ b/internal/discovery/service.go @@ -100,7 +100,7 @@ func (s *ScannerService) pollNetwork() { Status: PortClosed, }, } - s.handleDiscoveryResult(dr) + go s.handleDiscoveryResult(dr) case scanner.SYNResult: res := r.Payload.(*scanner.SynScanResult) dr := &DiscoveryResult{ @@ -115,7 +115,7 @@ func (s *ScannerService) pollNetwork() { Status: PortStatus(res.Port.Status), }, } - s.handleDiscoveryResult(dr) + go s.handleDiscoveryResult(dr) } case err := <-s.errorChan: s.log.Error().Err(err).Msg("discovery service encountered an error") diff --git a/internal/discovery/uname.go b/internal/discovery/uname.go index 5b2d850..c33747f 100644 --- a/internal/discovery/uname.go +++ b/internal/discovery/uname.go @@ -25,6 +25,7 @@ func NewUnameScanner(conf config.Config) *UnameScanner { func (s UnameScanner) GetServerDetails(ctx context.Context, ip string) (*Details, error) { user := s.conf.SSH.User identity := s.conf.SSH.Identity + port := s.conf.SSH.Port for _, o := range s.conf.SSH.Overrides { if o.Target == ip { @@ -35,12 +36,30 @@ func (s UnameScanner) GetServerDetails(ctx context.Context, ip string) (*Details if o.Identity != "" { identity = o.Identity } + + if o.Port != "" { + port = o.Port + } } } - cmd := exec.Command("ssh", "-i", identity, user+"@"+ip, "uname -a") - - unameOutput, err := cmd.Output() + unameCmd := exec.Command( + "ssh", + "-i", + identity, + "-p", + port, + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=no", + "-l", + user, + ip, + "uname -a", + ) + + unameOutput, err := unameCmd.Output() if err != nil { return nil, err @@ -48,16 +67,30 @@ func (s UnameScanner) GetServerDetails(ctx context.Context, ip string) (*Details info := strings.Split(string(unameOutput), " ") - os := info[0] + operatingSystem := info[0] hostname := info[1] - switch os { + switch operatingSystem { case "Darwin": - os = "MacOS" + operatingSystem = "MacOS" case "Linux": - cmd = exec.Command("ssh", "-i", identity, user+"@"+ip, "cat /etc/os-release") - - osReleaseOutput, err := cmd.Output() + osReleaseCmd := exec.Command( + "ssh", + "-i", + identity, + "-p", + port, + "-o", + "BatchMode=yes", + "-o", + "StrictHostKeyChecking=no", + "-l", + user, + ip, + "cat /etc/os-release", + ) + + osReleaseOutput, err := osReleaseCmd.Output() if err != nil { return nil, err @@ -67,13 +100,13 @@ func (s UnameScanner) GetServerDetails(ctx context.Context, ip string) (*Details for i, name := range osReleaseRegexp.SubexpNames() { if name == "os" { - os = match[i] + operatingSystem = match[i] } } } return &Details{ Hostname: hostname, - OS: os, + OS: operatingSystem, }, nil } diff --git a/internal/ui/component/configure.go b/internal/ui/component/configure.go index 088f534..8f4c23a 100644 --- a/internal/ui/component/configure.go +++ b/internal/ui/component/configure.go @@ -13,6 +13,7 @@ type ConfigureForm struct { configName *tview.InputField sshUserInput *tview.InputField sshIdentityInput *tview.InputField + sshPortInput *tview.InputField cidrInput *tview.InputField overrides []map[string]*tview.InputField conf config.Config @@ -26,7 +27,7 @@ type ConfigureForm struct { func addBlankFormItems( form *tview.Form, confName string, -) (*tview.InputField, *tview.InputField, *tview.InputField, *tview.InputField) { +) (*tview.InputField, *tview.InputField, *tview.InputField, *tview.InputField, *tview.InputField) { configName := tview.NewInputField() configName.SetLabel("Config Name: ") @@ -36,6 +37,9 @@ func addBlankFormItems( sshIdentityInput := tview.NewInputField() sshIdentityInput.SetLabel("SSH Identity: ") + sshPortInput := tview.NewInputField() + sshPortInput.SetLabel("SSH Port: ") + cidrInput := tview.NewInputField() cidrInput.SetLabel("Network CIDR: ") @@ -43,6 +47,7 @@ func addBlankFormItems( form.AddFormItem(cidrInput) form.AddFormItem(sshUserInput) form.AddFormItem(sshIdentityInput) + form.AddFormItem(sshPortInput) form.SetTitle(confName + " Configuration") form.SetBorder(true) @@ -55,11 +60,11 @@ func addBlankFormItems( style.StyleDefault.Background(style.ColorLightGreen), ) - return configName, sshUserInput, sshIdentityInput, cidrInput + return configName, sshUserInput, sshIdentityInput, sshPortInput, cidrInput } // every time the add ssh override button is clicked we add three new inputs -func createOverrideInputs(conf config.Config) (*tview.InputField, *tview.InputField, *tview.InputField) { +func createOverrideInputs(conf config.Config) (*tview.InputField, *tview.InputField, *tview.InputField, *tview.InputField) { overrideTarget := tview.NewInputField() overrideTarget.SetLabel("Override Target IP: ") @@ -71,7 +76,11 @@ func createOverrideInputs(conf config.Config) (*tview.InputField, *tview.InputFi overrideSSHIdentity.SetLabel("Override SSH Identity: ") overrideSSHIdentity.SetText(conf.SSH.Identity) - return overrideTarget, overrideSSHUser, overrideSSHIdentity + overrideSSHPort := tview.NewInputField() + overrideSSHPort.SetLabel("Override SSH Port: ") + overrideSSHPort.SetText(conf.SSH.Port) + + return overrideTarget, overrideSSHUser, overrideSSHIdentity, overrideSSHPort } // NewConfigureForm returns a new instance of ConfigureForm @@ -83,7 +92,7 @@ func NewConfigureForm( ) *ConfigureForm { form := tview.NewForm() - configName, sshUserInput, sshIdentityInput, cidrInput := addBlankFormItems( + configName, sshUserInput, sshIdentityInput, sshPortInput, cidrInput := addBlankFormItems( form, conf.Name, ) @@ -93,6 +102,7 @@ func NewConfigureForm( configName: configName, sshUserInput: sshUserInput, sshIdentityInput: sshIdentityInput, + sshPortInput: sshPortInput, cidrInput: cidrInput, overrides: []map[string]*tview.InputField{}, conf: conf, @@ -114,7 +124,7 @@ func (f *ConfigureForm) render() { f.root.Clear(true) f.overrides = []map[string]*tview.InputField{} - f.configName, f.sshUserInput, f.sshIdentityInput, f.cidrInput = + f.configName, f.sshUserInput, f.sshIdentityInput, f.sshPortInput, f.cidrInput = addBlankFormItems(f.root, f.conf.Name) networkTargets := f.conf.CIDR @@ -122,22 +132,29 @@ func (f *ConfigureForm) render() { f.configName.SetText(f.conf.Name) f.sshUserInput.SetText(f.conf.SSH.User) f.sshIdentityInput.SetText(f.conf.SSH.Identity) + f.sshPortInput.SetText(f.conf.SSH.Port) f.cidrInput.SetText(networkTargets) for _, o := range f.conf.SSH.Overrides { - target, user, identity := createOverrideInputs(f.conf) + target, user, identity, port := createOverrideInputs(f.conf) f.overrides = append(f.overrides, map[string]*tview.InputField{ "target": target, "user": user, "identity": identity, + "port": port, }) target.SetText(o.Target) user.SetText(o.User) identity.SetText(o.Identity) + port.SetText(o.Port) - f.root.AddFormItem(target).AddFormItem(user).AddFormItem(identity) + f.root. + AddFormItem(target). + AddFormItem(user). + AddFormItem(identity). + AddFormItem(port) } f.addFormButtons() @@ -156,15 +173,20 @@ func (f *ConfigureForm) addFormButtons() { }) f.root.AddButton("Add SSH Override", func() { - target, user, identity := createOverrideInputs(f.conf) + target, user, identity, port := createOverrideInputs(f.conf) f.overrides = append(f.overrides, map[string]*tview.InputField{ "target": target, "user": user, "identity": identity, + "port": port, }) - f.root.AddFormItem(target).AddFormItem(user).AddFormItem(identity) + f.root. + AddFormItem(target). + AddFormItem(user). + AddFormItem(identity). + AddFormItem(port) }) f.root.AddButton("New", func() { @@ -181,6 +203,7 @@ func (f *ConfigureForm) addFormButtons() { f.cidrInput.SetText("") f.sshUserInput.SetText("") f.sshIdentityInput.SetText("") + f.sshPortInput.SetText("") f.creatingNewConfig = true }) @@ -189,8 +212,9 @@ func (f *ConfigureForm) addFormButtons() { cidr := f.cidrInput.GetText() sshUser := f.sshUserInput.GetText() sshIdentity := f.sshIdentityInput.GetText() + sshPort := f.sshPortInput.GetText() - if name == "" || cidr == "" || sshUser == "" || sshIdentity == "" { + if name == "" || cidr == "" || sshUser == "" || sshIdentity == "" || sshPort == "" { f.creatingNewConfig = false return } @@ -202,6 +226,7 @@ func (f *ConfigureForm) addFormButtons() { Target: o["target"].GetText(), User: o["user"].GetText(), Identity: o["identity"].GetText(), + Port: o["port"].GetText(), } confOverrides = append(confOverrides, confOverride) @@ -212,6 +237,7 @@ func (f *ConfigureForm) addFormButtons() { SSH: config.SSHConfig{ User: sshUser, Identity: sshIdentity, + Port: sshPort, Overrides: confOverrides, }, CIDR: cidr, diff --git a/internal/ui/view.go b/internal/ui/view.go index a80424b..c63b5bf 100644 --- a/internal/ui/view.go +++ b/internal/ui/view.go @@ -368,6 +368,7 @@ func (v *view) onSSH(ip string) { conf := v.appCore.Conf() user := conf.SSH.User identity := conf.SSH.Identity + port := conf.SSH.Port for _, o := range conf.SSH.Overrides { if o.Target == ip { @@ -378,10 +379,23 @@ func (v *view) onSSH(ip string) { if o.Identity != "" { identity = o.Identity } + + if o.Port != "" { + port = o.Port + } } } - cmd := exec.Command("ssh", "-i", identity, user+"@"+ip) + cmd := exec.Command( + "ssh", + "-i", + identity, + "-p", + port, + "-l", + user, + ip, + ) restoreStdout()