Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Optionally serve OAuth 2.0 authentication with HTTPS #38

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 54 additions & 6 deletions oauth/authcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@ import (
"bufio"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"os"
"os/exec"
"path"
"runtime"
"strconv"
"strings"
"time"

"context"

"github.com/danielgtaylor/restish/cli"
"github.com/spf13/viper"
"golang.org/x/oauth2"
)

Expand Down Expand Up @@ -178,7 +182,7 @@ func (h authHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// AuthorizationCodeTokenSource with PKCE as described in:
// https://www.oauth.com/oauth2-servers/pkce/
// This works by running a local HTTP server on port 8484 and then having the
// This works by running a local HTTP or HTTPS server on port 8484 and then having the
// user log in through a web browser, which redirects to the local server with
// an authorization code. That code is then used to make another HTTP request
// to fetch an auth token (and refresh token). That token is then in turn
Expand All @@ -190,6 +194,7 @@ type AuthorizationCodeTokenSource struct {
TokenURL string
EndpointParams *url.Values
Scopes []string
UseHTTPS bool
}

// Token generates a new token using an authorization code.
Expand All @@ -213,12 +218,22 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
panic(err)
}

redirectURL := url.URL{
Host: "localhost:8484",
Path: "/",
}
if ac.UseHTTPS {
redirectURL.Scheme = "https"
} else {
redirectURL.Scheme = "http"
}

aq := authorizeURL.Query()
aq.Set("response_type", "code")
aq.Set("code_challenge", challenge)
aq.Set("code_challenge_method", "S256")
aq.Set("client_id", ac.ClientID)
aq.Set("redirect_uri", "http://localhost:8484/")
aq.Set("redirect_uri", redirectURL.String())
aq.Set("scope", strings.Join(ac.Scopes, " "))
if ac.EndpointParams != nil {
for k, v := range *ac.EndpointParams {
Expand All @@ -234,16 +249,38 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
}

s := &http.Server{
Addr: "localhost:8484",
Addr: redirectURL.Host,
Handler: handler,
ReadTimeout: 5 * time.Second,
WriteTimeout: 5 * time.Second,
MaxHeaderBytes: 1024,
}

if ac.UseHTTPS {
configDirectory := viper.GetString("config-directory")
certName := path.Join(configDirectory, "localhost.crt")
keyfileName := path.Join(configDirectory, "localhost.key")

cert, err := tls.LoadX509KeyPair(certName, keyfileName)
if err != nil {
panic(err)
}

s.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
}
}

go func() {
// Run in a goroutine until the server is closed or we get an error.
if err := s.ListenAndServe(); err != http.ErrServerClosed {
var err error
if ac.UseHTTPS {
err = s.ListenAndServeTLS("", "")
} else {
err = s.ListenAndServe()
}

if err != http.ErrServerClosed {
panic(err)
}
}()
Expand Down Expand Up @@ -279,7 +316,7 @@ func (ac *AuthorizationCodeTokenSource) Token() (*oauth2.Token, error) {
payload.Set("client_id", ac.ClientID)
payload.Set("code_verifier", verifier)
payload.Set("code", code)
payload.Set("redirect_uri", "http://localhost:8484/")
payload.Set("redirect_uri", redirectURL.String())
if ac.ClientSecret != "" {
payload.Set("client_secret", ac.ClientSecret)
}
Expand All @@ -299,6 +336,7 @@ func (h *AuthorizationCodeHandler) Parameters() []cli.AuthParam {
{Name: "authorize_url", Required: true, Help: "OAuth 2.0 authorization URL, e.g. https://api.example.com/oauth/authorize"},
{Name: "token_url", Required: true, Help: "OAuth 2.0 token URL, e.g. https://api.example.com/oauth/token"},
{Name: "scopes", Help: "Optional scopes to request in the token"},
{Name: "use_https", Help: "Use HTTPS for authentication page"},
}
}

Expand All @@ -307,21 +345,31 @@ func (h *AuthorizationCodeHandler) OnRequest(request *http.Request, key string,
if request.Header.Get("Authorization") == "" {
endpointParams := url.Values{}
for k, v := range params {
if k == "client_id" || k == "client_secret" || k == "scopes" || k == "authorize_url" || k == "token_url" {
if k == "client_id" || k == "client_secret" || k == "scopes" || k == "authorize_url" || k == "token_url" || k == "use_https" {
// Not a custom param...
continue
}

endpointParams.Add(k, v)
}

var useHTTPS bool
if v := params["use_https"]; v != "" {
var err error
useHTTPS, err = strconv.ParseBool(v)
if err != nil {
return err
}
}

source := &AuthorizationCodeTokenSource{
ClientID: params["client_id"],
ClientSecret: params["client_secret"],
AuthorizeURL: params["authorize_url"],
TokenURL: params["token_url"],
EndpointParams: &endpointParams,
Scopes: strings.Split(params["scopes"], ","),
UseHTTPS: useHTTPS,
}

// Try to get a cached refresh token from the current profile and use
Expand Down