Skip to content

Commit

Permalink
Escape values of extra credentials and session properties
Browse files Browse the repository at this point in the history
Correctly escape values of extra credentials and session properties when
building HTTP headers.

Change the format of extra credentials and session properties in the URL
to match the Trino JDBC and Python drivers - keys and values are
separated with a colon (`:`) and multiple entries are separated with a
semi-colon (`;`). Note: all special characters, including the semicolon,
must be URL escaped.
  • Loading branch information
nineinchnick committed Oct 6, 2024
1 parent d71f0cb commit a3de405
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 20 deletions.
2 changes: 1 addition & 1 deletion trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ handleErr:

func TestIntegrationSessionProperties(t *testing.T) {
dsn := *integrationServerFlag
dsn += "?session_properties=query_max_run_time=10m,query_priority=2"
dsn += "?session_properties=query_max_run_time%3A10m%3Bquery_priority%3A2"
db := integrationOpen(t, dsn)
defer db.Close()
rows, err := db.Query("SHOW SESSION")
Expand Down
71 changes: 60 additions & 11 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ const (
sslCertPathConfig = "SSLCertPath"
sslCertConfig = "SSLCert"
accessTokenConfig = "accessToken"

mapKeySeparator = ":"
mapEntrySeparator = ";"
)

var (
Expand Down Expand Up @@ -191,13 +194,13 @@ func (c *Config) FormatDSN() (string, error) {
var sessionkv []string
if c.SessionProperties != nil {
for k, v := range c.SessionProperties {
sessionkv = append(sessionkv, k+"="+v)
sessionkv = append(sessionkv, k+mapKeySeparator+v)
}
}
var credkv []string
if c.ExtraCredentials != nil {
for k, v := range c.ExtraCredentials {
credkv = append(credkv, k+"="+v)
credkv = append(credkv, k+mapKeySeparator+v)
}
}
source := c.Source
Expand Down Expand Up @@ -258,8 +261,8 @@ func (c *Config) FormatDSN() (string, error) {
for k, v := range map[string]string{
"catalog": c.Catalog,
"schema": c.Schema,
"session_properties": strings.Join(sessionkv, ","),
"extra_credentials": strings.Join(credkv, ","),
"session_properties": strings.Join(sessionkv, mapEntrySeparator),
"extra_credentials": strings.Join(credkv, mapEntrySeparator),
"custom_client": c.CustomClientName,
accessTokenConfig: c.AccessToken,
} {
Expand Down Expand Up @@ -368,22 +371,68 @@ func newConn(dsn string) (*Conn, error) {
}

for k, v := range map[string]string{
trinoUserHeader: user,
trinoSourceHeader: query.Get("source"),
trinoCatalogHeader: query.Get("catalog"),
trinoSchemaHeader: query.Get("schema"),
trinoSessionHeader: query.Get("session_properties"),
trinoExtraCredentialHeader: query.Get("extra_credentials"),
authorizationHeader: getAuthorization(query.Get(accessTokenConfig)),
trinoUserHeader: user,
trinoSourceHeader: query.Get("source"),
trinoCatalogHeader: query.Get("catalog"),
trinoSchemaHeader: query.Get("schema"),
authorizationHeader: getAuthorization(query.Get(accessTokenConfig)),
} {
if v != "" {
c.httpHeaders.Add(k, v)
}
}
for header, param := range map[string]string{
trinoSessionHeader: "session_properties",
trinoExtraCredentialHeader: "extra_credentials",
} {
v := query.Get(param)
if v != "" {
c.httpHeaders[header], err = decodeMapHeader(param, v)
if err != nil {
return c, err
}
}
}

return c, nil
}

func decodeMapHeader(name, input string) ([]string, error) {
result := []string{}
for _, entry := range strings.Split(input, mapEntrySeparator) {
parts := strings.SplitN(entry, mapKeySeparator, 2)
if len(parts) != 2 {
return nil, fmt.Errorf("trino: Malformed %s: %s", name, input)
}
key := parts[0]
value := parts[1]
if len(key) == 0 {
return nil, fmt.Errorf("trino: %s key is empty", name)
}
if len(value) == 0 {
return nil, fmt.Errorf("trino: %s value is empty", name)
}
if !isASCII(key) {
return nil, fmt.Errorf("trino: %s key '%s' contains spaces or is not printable ASCII", name, key)
}
if !isASCII(value) {
// do not log value as it may contain sensitive information
return nil, fmt.Errorf("trino: %s value for key '%s' contains spaces or is not printable ASCII", name, key)
}
result = append(result, key+"="+url.QueryEscape(value))
}
return result, nil
}

func isASCII(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] < '\u0021' || s[i] > '\u007E' {
return false
}
}
return true
}

func getAuthorization(token string) string {
if token == "" {
return ""
Expand Down
60 changes: 52 additions & 8 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestConfig(t *testing.T) {
dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "http://foobar@localhost:8080?session_properties=query_priority%3D1&source=trino-go-client"
want := "http://foobar@localhost:8080?session_properties=query_priority%3A1&source=trino-go-client"

assert.Equal(t, want, dsn)
}
Expand All @@ -58,7 +58,7 @@ func TestConfigSSLCertPath(t *testing.T) {
dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3D1&source=trino-go-client"
want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3A1&source=trino-go-client"

assert.Equal(t, want, dsn)
}
Expand Down Expand Up @@ -105,25 +105,69 @@ FKu5ZAlRfb2aYegr49DHhzoVAdInWQmP+5EZEUD1
dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "https://foobar@localhost:8080?SSLCert=" + url.QueryEscape(sslCert) + "&session_properties=query_priority%3D1&source=trino-go-client"
want := "https://foobar@localhost:8080?SSLCert=" + url.QueryEscape(sslCert) + "&session_properties=query_priority%3A1&source=trino-go-client"

assert.Equal(t, want, dsn)
}

func TestExtraCredentials(t *testing.T) {
c := &Config{
ServerURI: "http://foobar@localhost:8080",
ExtraCredentials: map[string]string{"token": "mYtOkEn", "otherToken": "oThErToKeN"},
ExtraCredentials: map[string]string{"token": "mYtOkEn", "otherToken": "oThErToKeN%*!#@special"},
}

dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "http://foobar@localhost:8080?extra_credentials=otherToken%3DoThErToKeN%2Ctoken%3DmYtOkEn&source=trino-go-client"

want := "http://foobar@localhost:8080?extra_credentials=otherToken%3AoThErToKeN%25%2A%21%23%40special%3Btoken%3AmYtOkEn&source=trino-go-client"
assert.Equal(t, want, dsn)
}

func TestInvalidExtraCredentials(t *testing.T) {
testcases := []struct {
Name string
Credentials map[string]string
Error string
}{
{
Name: "Empty key",
Credentials: map[string]string{"": "emptyKey"},
Error: "trino: extra_credentials key is empty",
},
{
Name: "Empty value",
Credentials: map[string]string{"valid": "a", "emptyValue": ""},
Error: "trino: extra_credentials value is empty",
},
{
Name: "Unprintable key",
Credentials: map[string]string{"😊": "unprintableKey"},
Error: "trino: extra_credentials key '😊' contains spaces or is not printable ASCII",
},
{
Name: "Unprintable value",
Credentials: map[string]string{"unprintableValue": "😊"},
Error: "trino: extra_credentials value for key 'unprintableValue' contains spaces or is not printable ASCII",
},
}

for _, tc := range testcases {

t.Run(tc.Name, func(t *testing.T) {
c := &Config{
ServerURI: "http://foobar@localhost:8080",
ExtraCredentials: tc.Credentials,
}
dsn, err := c.FormatDSN()
require.NoError(t, err)
db, err := sql.Open("trino", dsn)
require.NoError(t, err)
err = db.Ping()
assert.EqualError(t, err, tc.Error)
})
}
}

func TestConfigWithoutSSLCertPath(t *testing.T) {
c := &Config{
ServerURI: "https://foobar@localhost:8080",
Expand All @@ -132,7 +176,7 @@ func TestConfigWithoutSSLCertPath(t *testing.T) {
dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "https://foobar@localhost:8080?session_properties=query_priority%3D1&source=trino-go-client"
want := "https://foobar@localhost:8080?session_properties=query_priority%3A1&source=trino-go-client"

assert.Equal(t, want, dsn)
}
Expand All @@ -153,7 +197,7 @@ func TestKerberosConfig(t *testing.T) {
dsn, err := c.FormatDSN()
require.NoError(t, err)

want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&KerberosRemoteServiceName=service&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3D1&source=trino-go-client"
want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&KerberosRemoteServiceName=service&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3A1&source=trino-go-client"

assert.Equal(t, want, dsn)
}
Expand Down

0 comments on commit a3de405

Please sign in to comment.