diff --git a/trino/integration_test.go b/trino/integration_test.go index d78c1ae..86e17d5 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -496,7 +496,7 @@ handleErr: func TestIntegrationSessionProperties(t *testing.T) { dsn := *integrationServerFlag - dsn += "?session_properties=query_max_run_time=10m,query_priority=2" + dsn += "?session_properties=query_max_run_time%3A10m%3Bquery_priority%3A2" db := integrationOpen(t, dsn) defer db.Close() rows, err := db.Query("SHOW SESSION") diff --git a/trino/trino.go b/trino/trino.go index 5204717..48d74c6 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -141,6 +141,9 @@ const ( sslCertPathConfig = "SSLCertPath" sslCertConfig = "SSLCert" accessTokenConfig = "accessToken" + + mapKeySeparator = ":" + mapEntrySeparator = ";" ) var ( @@ -191,13 +194,13 @@ func (c *Config) FormatDSN() (string, error) { var sessionkv []string if c.SessionProperties != nil { for k, v := range c.SessionProperties { - sessionkv = append(sessionkv, k+"="+v) + sessionkv = append(sessionkv, k+mapKeySeparator+v) } } var credkv []string if c.ExtraCredentials != nil { for k, v := range c.ExtraCredentials { - credkv = append(credkv, k+"="+v) + credkv = append(credkv, k+mapKeySeparator+v) } } source := c.Source @@ -258,8 +261,8 @@ func (c *Config) FormatDSN() (string, error) { for k, v := range map[string]string{ "catalog": c.Catalog, "schema": c.Schema, - "session_properties": strings.Join(sessionkv, ","), - "extra_credentials": strings.Join(credkv, ","), + "session_properties": strings.Join(sessionkv, mapEntrySeparator), + "extra_credentials": strings.Join(credkv, mapEntrySeparator), "custom_client": c.CustomClientName, accessTokenConfig: c.AccessToken, } { @@ -368,22 +371,68 @@ func newConn(dsn string) (*Conn, error) { } for k, v := range map[string]string{ - trinoUserHeader: user, - trinoSourceHeader: query.Get("source"), - trinoCatalogHeader: query.Get("catalog"), - trinoSchemaHeader: query.Get("schema"), - trinoSessionHeader: query.Get("session_properties"), - trinoExtraCredentialHeader: query.Get("extra_credentials"), - authorizationHeader: getAuthorization(query.Get(accessTokenConfig)), + trinoUserHeader: user, + trinoSourceHeader: query.Get("source"), + trinoCatalogHeader: query.Get("catalog"), + trinoSchemaHeader: query.Get("schema"), + authorizationHeader: getAuthorization(query.Get(accessTokenConfig)), } { if v != "" { c.httpHeaders.Add(k, v) } } + for header, param := range map[string]string{ + trinoSessionHeader: "session_properties", + trinoExtraCredentialHeader: "extra_credentials", + } { + v := query.Get(param) + if v != "" { + c.httpHeaders[header], err = decodeMapHeader(param, v) + if err != nil { + return c, err + } + } + } return c, nil } +func decodeMapHeader(name, input string) ([]string, error) { + result := []string{} + for _, entry := range strings.Split(input, mapEntrySeparator) { + parts := strings.SplitN(entry, mapKeySeparator, 2) + if len(parts) != 2 { + return nil, fmt.Errorf("trino: Malformed %s: %s", name, input) + } + key := parts[0] + value := parts[1] + if len(key) == 0 { + return nil, fmt.Errorf("trino: %s key is empty", name) + } + if len(value) == 0 { + return nil, fmt.Errorf("trino: %s value is empty", name) + } + if !isASCII(key) { + return nil, fmt.Errorf("trino: %s key '%s' contains spaces or is not printable ASCII", name, key) + } + if !isASCII(value) { + // do not log value as it may contain sensitive information + return nil, fmt.Errorf("trino: %s value for key '%s' contains spaces or is not printable ASCII", name, key) + } + result = append(result, key+"="+url.QueryEscape(value)) + } + return result, nil +} + +func isASCII(s string) bool { + for i := 0; i < len(s); i++ { + if s[i] < '\u0021' || s[i] > '\u007E' { + return false + } + } + return true +} + func getAuthorization(token string) string { if token == "" { return "" diff --git a/trino/trino_test.go b/trino/trino_test.go index 502a749..42f08f5 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -43,7 +43,7 @@ func TestConfig(t *testing.T) { dsn, err := c.FormatDSN() require.NoError(t, err) - want := "http://foobar@localhost:8080?session_properties=query_priority%3D1&source=trino-go-client" + want := "http://foobar@localhost:8080?session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } @@ -58,7 +58,7 @@ func TestConfigSSLCertPath(t *testing.T) { dsn, err := c.FormatDSN() require.NoError(t, err) - want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3D1&source=trino-go-client" + want := "https://foobar@localhost:8080?SSLCertPath=cert.pem&session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } @@ -105,7 +105,7 @@ FKu5ZAlRfb2aYegr49DHhzoVAdInWQmP+5EZEUD1 dsn, err := c.FormatDSN() require.NoError(t, err) - want := "https://foobar@localhost:8080?SSLCert=" + url.QueryEscape(sslCert) + "&session_properties=query_priority%3D1&source=trino-go-client" + want := "https://foobar@localhost:8080?SSLCert=" + url.QueryEscape(sslCert) + "&session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } @@ -113,17 +113,61 @@ FKu5ZAlRfb2aYegr49DHhzoVAdInWQmP+5EZEUD1 func TestExtraCredentials(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8080", - ExtraCredentials: map[string]string{"token": "mYtOkEn", "otherToken": "oThErToKeN"}, + ExtraCredentials: map[string]string{"token": "mYtOkEn", "otherToken": "oThErToKeN%*!#@special"}, } dsn, err := c.FormatDSN() require.NoError(t, err) - want := "http://foobar@localhost:8080?extra_credentials=otherToken%3DoThErToKeN%2Ctoken%3DmYtOkEn&source=trino-go-client" - + want := "http://foobar@localhost:8080?extra_credentials=otherToken%3AoThErToKeN%25%2A%21%23%40special%3Btoken%3AmYtOkEn&source=trino-go-client" assert.Equal(t, want, dsn) } +func TestInvalidExtraCredentials(t *testing.T) { + testcases := []struct { + Name string + Credentials map[string]string + Error string + }{ + { + Name: "Empty key", + Credentials: map[string]string{"": "emptyKey"}, + Error: "trino: extra_credentials key is empty", + }, + { + Name: "Empty value", + Credentials: map[string]string{"valid": "a", "emptyValue": ""}, + Error: "trino: extra_credentials value is empty", + }, + { + Name: "Unprintable key", + Credentials: map[string]string{"😊": "unprintableKey"}, + Error: "trino: extra_credentials key '😊' contains spaces or is not printable ASCII", + }, + { + Name: "Unprintable value", + Credentials: map[string]string{"unprintableValue": "😊"}, + Error: "trino: extra_credentials value for key 'unprintableValue' contains spaces or is not printable ASCII", + }, + } + + for _, tc := range testcases { + + t.Run(tc.Name, func(t *testing.T) { + c := &Config{ + ServerURI: "http://foobar@localhost:8080", + ExtraCredentials: tc.Credentials, + } + dsn, err := c.FormatDSN() + require.NoError(t, err) + db, err := sql.Open("trino", dsn) + require.NoError(t, err) + err = db.Ping() + assert.EqualError(t, err, tc.Error) + }) + } +} + func TestConfigWithoutSSLCertPath(t *testing.T) { c := &Config{ ServerURI: "https://foobar@localhost:8080", @@ -132,7 +176,7 @@ func TestConfigWithoutSSLCertPath(t *testing.T) { dsn, err := c.FormatDSN() require.NoError(t, err) - want := "https://foobar@localhost:8080?session_properties=query_priority%3D1&source=trino-go-client" + want := "https://foobar@localhost:8080?session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) } @@ -153,7 +197,7 @@ func TestKerberosConfig(t *testing.T) { dsn, err := c.FormatDSN() require.NoError(t, err) - want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&KerberosRemoteServiceName=service&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3D1&source=trino-go-client" + want := "https://foobar@localhost:8090?KerberosConfigPath=%2Fetc%2Fkrb5.conf&KerberosEnabled=true&KerberosKeytabPath=%2Fopt%2Ftest.keytab&KerberosPrincipal=trino%2Ftesthost&KerberosRealm=example.com&KerberosRemoteServiceName=service&SSLCertPath=%2Ftmp%2Ftest.cert&session_properties=query_priority%3A1&source=trino-go-client" assert.Equal(t, want, dsn) }