Skip to content

Commit

Permalink
Add optional uaa-target flag
Browse files Browse the repository at this point in the history
- The UAA target defaults to /uaa on Operations Manager, but can now be
overridden in cases where UAA is not located at the default location.
- Consistently parse target and allow http(s):// protocol prefix.
  • Loading branch information
sneal authored and Ryan Hall committed Jul 23, 2024
1 parent 9eedd62 commit 5402fc2
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 66 deletions.
9 changes: 5 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ type options struct {
ClientID string `yaml:"client-id" short:"c" long:"client-id" env:"OM_CLIENT_ID" description:"Client ID for the Ops Manager VM (not required for unauthenticated commands)"`
ClientSecret string `yaml:"client-secret" short:"s" long:"client-secret" env:"OM_CLIENT_SECRET" description:"Client Secret for the Ops Manager VM (not required for unauthenticated commands)"`
ConnectTimeout int `yaml:"connect-timeout" short:"o" long:"connect-timeout" env:"OM_CONNECT_TIMEOUT" default:"10" description:"timeout in seconds to make TCP connections"`
DecryptionPassphrase string `yaml:"decryption-passphrase" short:"d" long:"decryption-passphrase" env:"OM_DECRYPTION_PASSPHRASE" description:"Passphrase to decrypt the installation if the Ops Manager VM has been rebooted (optional for most commands)"`
Env string ` short:"e" long:"env" description:"env file with login credentials"`
DecryptionPassphrase string `yaml:"decryption-passphrase" short:"d" long:"decryption-passphrase" env:"OM_DECRYPTION_PASSPHRASE" description:"Passphrase to decrypt the installation if the Ops Manager VM has been rebooted (optional for most commands)"`
Env string ` short:"e" long:"env" description:"env file with login credentials"`
Password string `yaml:"password" short:"p" long:"password" env:"OM_PASSWORD" description:"admin password for the Ops Manager VM (not required for unauthenticated commands)"`
RequestTimeout int `yaml:"request-timeout" short:"r" long:"request-timeout" env:"OM_REQUEST_TIMEOUT" default:"1800" description:"timeout in seconds for HTTP requests to Ops Manager"`
SkipSSLValidation bool `yaml:"skip-ssl-validation" short:"k" long:"skip-ssl-validation" env:"OM_SKIP_SSL_VALIDATION" description:"skip ssl certificate validation during http requests"`
Target string `yaml:"target" short:"t" long:"target" env:"OM_TARGET" description:"location of the Ops Manager VM"`
UAATarget string `yaml:"uaa-target" long:"uaa-target" env:"OM_UAA_TARGET" description:"optional location of the Ops Manager UAA"`
Trace bool `yaml:"trace" long:"trace" env:"OM_TRACE" description:"prints HTTP requests and response payloads"`
Username string `yaml:"username" short:"u" long:"username" env:"OM_USERNAME" description:"admin username for the Ops Manager VM (not required for unauthenticated commands)"`
VarsEnv string ` long:"vars-env" env:"OM_VARS_ENV" description:"load vars from environment variables by specifying a prefix (e.g.: 'MY' to load MY_var=value)"`
Expand Down Expand Up @@ -79,8 +80,7 @@ func Main(sout io.Writer, serr io.Writer, version string, applySleepDurationStri
return err
}

authedClient, err = network.NewOAuthClient(global.Target, global.Username, global.Password, global.ClientID, global.ClientSecret, global.SkipSSLValidation, global.CACert, connectTimeout, requestTimeout)

authedClient, err = network.NewOAuthClient(global.UAATarget, global.Target, global.Username, global.Password, global.ClientID, global.ClientSecret, global.SkipSSLValidation, global.CACert, connectTimeout, requestTimeout)
if err != nil {
return err
}
Expand Down Expand Up @@ -663,6 +663,7 @@ func Main(sout io.Writer, serr io.Writer, version string, applySleepDurationStri
}
}
}

return err
}

Expand Down
29 changes: 11 additions & 18 deletions network/oauth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package network
import (
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/cloudfoundry-community/go-uaa"
Expand All @@ -17,15 +15,17 @@ type OAuthClient struct {
clientSecret string
insecureSkipVerify bool
password string
target string
opsmanTarget string
uaaTarget string
token *oauth2.Token
username string
connectTimeout time.Duration
requestTimeout time.Duration
}

func NewOAuthClient(
target, username, password string,
uaaTarget, opsmanTarget string,
username, password string,
clientID, clientSecret string,
insecureSkipVerify bool,
caCert string,
Expand All @@ -38,7 +38,8 @@ func NewOAuthClient(
clientSecret: clientSecret,
insecureSkipVerify: insecureSkipVerify,
password: password,
target: target,
uaaTarget: uaaTarget,
opsmanTarget: opsmanTarget,
username: username,
connectTimeout: connectTimeout,
requestTimeout: requestTimeout,
Expand All @@ -47,21 +48,13 @@ func NewOAuthClient(

func (oc *OAuthClient) Do(request *http.Request) (*http.Response, error) {
token := oc.token
target := oc.target

if !strings.HasPrefix(target, "http://") && !strings.HasPrefix(target, "https://") {
target = "https://" + target
}

targetURL, err := url.Parse(target)
opsmanTarget, uaaTarget, err := parseOpsmanAndUAAURLs(oc.opsmanTarget, oc.uaaTarget)
if err != nil {
return nil, fmt.Errorf("could not parse target url: %s", err)
return nil, err
}

targetURL.Path = "/uaa"

request.URL.Scheme = targetURL.Scheme
request.URL.Host = targetURL.Host
request.URL.Scheme = opsmanTarget.Scheme
request.URL.Host = opsmanTarget.Host

client, err := newHTTPClient(
oc.insecureSkipVerify,
Expand Down Expand Up @@ -106,7 +99,7 @@ func (oc *OAuthClient) Do(request *http.Request) (*http.Response, error) {
}

api, err := uaa.New(
targetURL.String(),
uaaTarget.String(),
authOption,
options...,
)
Expand Down
32 changes: 16 additions & 16 deletions network/oauth_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ var _ = Describe("OAuthClient", func() {
Describe("Do", func() {
When("with a request timeout", func() {
It("use that timeout value", func() {
client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Nanosecond, time.Nanosecond)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Nanosecond, time.Nanosecond)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand Down Expand Up @@ -61,7 +61,7 @@ var _ = Describe("OAuthClient", func() {
ghttp.RespondWith(http.StatusOK, nil),
)

client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(100)*time.Millisecond, time.Duration(100)*time.Millisecond)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(100)*time.Millisecond, time.Duration(100)*time.Millisecond)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand Down Expand Up @@ -93,7 +93,7 @@ var _ = Describe("OAuthClient", func() {
ghttp.RespondWith(http.StatusOK, nil),
)

client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(100)*time.Millisecond, time.Duration(100)*time.Millisecond)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(100)*time.Millisecond, time.Duration(100)*time.Millisecond)
Expect(err).ToNot(HaveOccurred())

for i := 0; i < 2; i++ {
Expand All @@ -120,7 +120,7 @@ var _ = Describe("OAuthClient", func() {
ghttp.RespondWith(http.StatusOK, ""),
)

client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(100)*time.Millisecond, time.Duration(100)*time.Millisecond)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(100)*time.Millisecond, time.Duration(100)*time.Millisecond)
Expect(err).ToNot(HaveOccurred())

for i := 0; i < 2; i++ {
Expand Down Expand Up @@ -153,7 +153,7 @@ var _ = Describe("OAuthClient", func() {
))
server.AppendHandlers(ghttp.RespondWith(http.StatusOK, nil))

client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand Down Expand Up @@ -182,7 +182,7 @@ var _ = Describe("OAuthClient", func() {
))
server.AppendHandlers(ghttp.RespondWith(http.StatusOK, nil))

client, err := network.NewOAuthClient(server.URL(), "", "", "client_id", "client_secret", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", server.URL(), "", "", "client_id", "client_secret", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand All @@ -200,7 +200,7 @@ var _ = Describe("OAuthClient", func() {
nonTLS12Server.Config.ErrorLog = log.New(GinkgoWriter, "", 0)
defer nonTLS12Server.Close()

client, err := network.NewOAuthClient(nonTLS12Server.URL, "", "", "client_id", "client_secret", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", nonTLS12Server.URL, "", "", "client_id", "client_secret", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand All @@ -220,7 +220,7 @@ var _ = Describe("OAuthClient", func() {
noScheme.Scheme = ""
finalURL := noScheme.String()[2:] // removing leading "//"

client, err := network.NewOAuthClient(finalURL, "opsman-username", "opsman-password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", finalURL, "opsman-username", "opsman-password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand All @@ -236,7 +236,7 @@ var _ = Describe("OAuthClient", func() {
When("insecureSkipVerify is configured", func() {
When("it is set to false", func() {
It("throws an error for invalid certificates", func() {
client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", false, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", false, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand All @@ -251,7 +251,7 @@ var _ = Describe("OAuthClient", func() {
It("does not verify certificates", func() {
setupBasicOauth(server)

client, err := network.NewOAuthClient(server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", server.URL(), "opsman-username", "opsman-password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand All @@ -272,7 +272,7 @@ var _ = Describe("OAuthClient", func() {
pemCert := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}))

client, err := network.NewOAuthClient(
server.URL(),
"", server.URL(),
"opsman-username", "opsman-password",
"", "",
false,
Expand All @@ -297,7 +297,7 @@ var _ = Describe("OAuthClient", func() {
pemCert := writeFile(string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})))

client, err := network.NewOAuthClient(
server.URL(),
"", server.URL(),
"opsman-username", "opsman-password",
"", "",
false,
Expand Down Expand Up @@ -327,7 +327,7 @@ var _ = Describe("OAuthClient", func() {
})

It("returns an error", func() {
client, err := network.NewOAuthClient(badServer.URL, "username", "password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", badServer.URL, "username", "password", "", "", true, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expand All @@ -338,16 +338,16 @@ var _ = Describe("OAuthClient", func() {
})
})

When("the target url is empty", func() {
When("the UAA and Opsman target url are empty", func() {
It("returns an error", func() {
client, err := network.NewOAuthClient("", "username", "password", "", "", false, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
client, err := network.NewOAuthClient("", "", "username", "password", "", "", false, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
Expect(err).ToNot(HaveOccurred())

req, err := http.NewRequest("GET", "/some/path", strings.NewReader("request-body"))
Expect(err).ToNot(HaveOccurred())

_, err = client.Do(req)
Expect(err).To(MatchError(ContainSubstring("")))
Expect(err).To(MatchError(ContainSubstring("could not parse Opsman target URL")))
})
})
})
Expand Down
21 changes: 2 additions & 19 deletions network/unauthenticated_client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
package network

import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)

Expand All @@ -27,22 +23,9 @@ func NewUnauthenticatedClient(target string, insecureSkipVerify bool, caCert str
}

func (c UnauthenticatedClient) Do(request *http.Request) (*http.Response, error) {
candidateURL := c.target
if !strings.Contains(candidateURL, "//") {
candidateURL = fmt.Sprintf("//%s", candidateURL)
}

targetURL, err := url.Parse(candidateURL)
targetURL, err := parseURL(c.target)
if err != nil {
return nil, fmt.Errorf("could not parse target url: %s", err)
}

if targetURL.Scheme == "" {
targetURL.Scheme = "https"
}

if targetURL.Host == "" {
return nil, errors.New("target flag is required. Run `om help` for more info.")
return nil, err
}

request.URL.Scheme = targetURL.Scheme
Expand Down
10 changes: 1 addition & 9 deletions network/unauthenticated_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,11 @@ var _ = Describe("UnauthenticatedClient", func() {
})

Context("failure cases", func() {
When("the target url cannot be parsed", func() {
It("returns an error", func() {
client, _ := network.NewUnauthenticatedClient("%%%", false, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
_, err := client.Do(&http.Request{})
Expect(err).To(MatchError("could not parse target url: parse \"//%%%\": invalid URL escape \"%%%\""))
})
})

When("the target url is empty", func() {
It("returns an error", func() {
client, _ := network.NewUnauthenticatedClient("", false, "", time.Duration(5)*time.Second, time.Duration(30)*time.Second)
_, err := client.Do(&http.Request{})
Expect(err).To(MatchError("target flag is required. Run `om help` for more info."))
Expect(err).To(MatchError("target flag is required, run `om help` for more info"))
})
})
})
Expand Down
55 changes: 55 additions & 0 deletions network/url.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package network

import (
"errors"
"fmt"
"net/url"
"strings"
)

// parseURL takes a candidate target URL and attempts to parse it with sane defaults
func parseURL(u string) (*url.URL, error) {
// default the target protocol to https if none specified
if !strings.Contains(u, "://") {
u = "https://" + u
}

targetURL, err := url.Parse(u)
if err != nil {
return nil, err
}

// at a minimum ensure we have a host with http(s) protocol
if targetURL.Scheme != "https" && targetURL.Scheme != "http" {
return nil, fmt.Errorf("error parsing URL, expected http(s) protocol but got %s", targetURL.Scheme)
}
if targetURL.Host == "" {
return nil, errors.New("target flag is required, run `om help` for more info")
}

return targetURL, nil
}

// parseOpsmanAndUAAURLs takes a candidate OpsMan and UAA target URLs and attempts to parse both of them, defaulting
// the UAA target to the /uaa path under the OpsMan target if none specified.
func parseOpsmanAndUAAURLs(opsmanTarget, uaaTarget string) (*url.URL, *url.URL, error) {
opsmanURL, err := parseURL(opsmanTarget)
if err != nil {
return nil, nil, fmt.Errorf("could not parse Opsman target URL: %w", err)
}

var uaaURL *url.URL
if uaaTarget != "" {
uaaURL, err = parseURL(uaaTarget)
if err != nil {
return nil, nil, fmt.Errorf("could not parse UAA target URL: %w", err)
}
} else {
// default to opsman URL with /uaa path (shallow copy)
t := *opsmanURL
t.Path = "/uaa"
uaaURL = &t
}

return opsmanURL, uaaURL, nil
}
Loading

0 comments on commit 5402fc2

Please sign in to comment.