Skip to content

Commit

Permalink
Revise the security schema and fix unescape characters (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac authored Nov 17, 2024
1 parent 20f76ab commit cd7fa66
Show file tree
Hide file tree
Showing 61 changed files with 2,114 additions and 3,455 deletions.
3 changes: 0 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ linters:
- err113
- lll
- gocognit
- execinquery
- exportloopref
- gomnd
- funlen
- godot
- gofumpt
Expand All @@ -20,7 +18,6 @@ linters:
- wsl
- wrapcheck
- varnamelen
- nlreturn
- exhaustive
- exhaustruct
- gocyclo
Expand Down
1 change: 1 addition & 0 deletions connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func (c *HTTPConnector) ParseConfiguration(ctx context.Context, configurationDir
schemas, errs = configuration.BuildSchemaFromConfig(config, configurationDir, logger)
if len(errs) > 0 {
printSchemaValidationError(logger, errs)

return nil, errBuildSchemaFailed
}
}
Expand Down
5 changes: 3 additions & 2 deletions connector/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,9 @@ func createMockServer(t *testing.T, apiKey string, bearerToken string) *httptest
mux.HandleFunc("/model", func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodPost:
if r.Header.Get("api_key") != apiKey {
t.Errorf("invalid api key, expected %s, got %s", apiKey, r.Header.Get("api_key"))
user, password, ok := r.BasicAuth()
if !ok || user != "user" || password != "password" {
t.Errorf("invalid basic auth, expected user:password, got %s:%s", user, password)
t.FailNow()
return
}
Expand Down
20 changes: 18 additions & 2 deletions connector/internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"log/slog"
"math"
"net/http"
"net/url"
"slices"
"strconv"
"strings"
Expand Down Expand Up @@ -250,12 +251,21 @@ func (client *HTTPClient) doRequest(ctx context.Context, request *RetryableReque
ctx, span := client.tracer.Start(ctx, fmt.Sprintf("%s %s", method, request.RawRequest.URL), trace.WithSpanKind(trace.SpanKindClient))
defer span.End()

requestURL := request.URL.String()
urlAttr := cloneURL(&request.URL)
password, hasPassword := urlAttr.User.Password()
if urlAttr.User.String() != "" || hasPassword {
maskedUser := MaskString(urlAttr.User.Username())
if hasPassword {
urlAttr.User = url.UserPassword(maskedUser, MaskString(password))
} else {
urlAttr.User = url.User(maskedUser)
}
}

span.SetAttributes(
attribute.String("db.system", "http"),
attribute.String("http.request.method", method),
attribute.String("url.full", requestURL),
attribute.String("url.full", urlAttr.String()),
attribute.String("server.address", request.URL.Hostname()),
attribute.Int("server.port", port),
attribute.String("network.protocol.name", "http"),
Expand Down Expand Up @@ -355,12 +365,14 @@ func evalHTTPResponse(ctx context.Context, span trace.Span, resp *http.Response,
if len(respBody) == 0 {
return nil, resp.Header, nil
}

return string(respBody), resp.Header, nil
case "text/plain", "text/html":
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, schema.NewConnectorError(http.StatusInternalServerError, err.Error(), nil)
}

return string(respBody), resp.Header, nil
case rest.ContentTypeJSON:
if len(resultType) > 0 {
Expand All @@ -378,6 +390,7 @@ func evalHTTPResponse(ctx context.Context, span trace.Span, resp *http.Response,
// fallback to raw string response if the result type is String
return string(respBytes), resp.Header, nil
}

return strResult, resp.Header, nil
}
}
Expand All @@ -395,6 +408,7 @@ func evalHTTPResponse(ctx context.Context, span trace.Span, resp *http.Response,
if err != nil {
return nil, nil, schema.InternalServerError(err.Error(), nil)
}

return result, resp.Header, nil
case rest.ContentTypeNdJSON:
var results []any
Expand All @@ -415,6 +429,7 @@ func evalHTTPResponse(ctx context.Context, span trace.Span, resp *http.Response,
if err != nil {
return nil, nil, schema.InternalServerError(err.Error(), nil)
}

return result, resp.Header, nil
default:
return nil, nil, schema.NewConnectorError(http.StatusInternalServerError, "failed to evaluate response", map[string]any{
Expand All @@ -428,5 +443,6 @@ func parseContentType(input string) string {
return ""
}
parts := strings.Split(input, ";")

return strings.TrimSpace(parts[0])
}
1 change: 1 addition & 0 deletions connector/internal/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func DecodeDataURI(input string) (*DataURI, error) {
if err != nil {
return nil, err
}

return &DataURI{
Data: string(rawDecodedBytes),
}, nil
Expand Down
2 changes: 2 additions & 0 deletions connector/internal/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func (rms MetadataCollection) GetFunction(name string) (*rest.OperationInfo, con
return fn, rm, nil
}
}

return nil, configuration.NDCHttpRuntimeSchema{}, schema.UnprocessableContentError("unsupported query: "+name, nil)
}

Expand All @@ -28,5 +29,6 @@ func (rms MetadataCollection) GetProcedure(name string) (*rest.OperationInfo, co
return fn, rm, nil
}
}

return nil, configuration.NDCHttpRuntimeSchema{}, schema.UnprocessableContentError("unsupported mutation: "+name, nil)
}
3 changes: 3 additions & 0 deletions connector/internal/multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ func (w *MultipartWriter) WriteJSON(fieldName string, value any, headers http.He
}

_, err = p.Write(bs)

return err
}

Expand All @@ -90,6 +91,7 @@ func (w *MultipartWriter) WriteField(fieldName, value string, headers http.Heade
return err
}
_, err = p.Write([]byte(value))

return err
}

Expand All @@ -100,5 +102,6 @@ func createFieldMIMEHeader(fieldName string, headers http.Header) textproto.MIME
}
h.Set("Content-Disposition",
fmt.Sprintf(`form-data; name="%s"`, escapeQuotes(fieldName)))

return h
}
10 changes: 10 additions & 0 deletions connector/internal/parameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@ func (ssp ParameterItems) String() string {
str := item.String()
results = append(results, str)
}

return strings.Join(results, "&")
}

func (ssp *ParameterItems) Add(keys []Key, values []string) {
index := ssp.FindIndex(keys)
if index == -1 {
*ssp = append(*ssp, NewParameterItem(keys, values))

return
}
(*ssp)[index].AddValues(values)
Expand All @@ -41,16 +43,19 @@ func (ssp ParameterItems) FindDefault() *ParameterItem {
return item
}
item, _ = ssp.find([]Key{})

return item
}

func (ssp ParameterItems) Find(keys []Key) *ParameterItem {
item, _ := ssp.find(keys)

return item
}

func (ssp ParameterItems) FindIndex(keys []Key) int {
_, i := ssp.find(keys)

return i
}

Expand All @@ -73,6 +78,7 @@ func (ssp ParameterItems) find(keys []Key) (*ParameterItem, int) {
return &item, i
}
}

return nil, -1
}

Expand All @@ -89,6 +95,7 @@ func (ks Keys) String() string {
for i, k := range ks {
if k.index != nil {
sb.WriteString(fmt.Sprintf("[%d]", *k.index))

continue
}
if k.key != "" {
Expand All @@ -98,6 +105,7 @@ func (ks Keys) String() string {
sb.WriteString(k.key)
}
}

return sb.String()
}

Expand Down Expand Up @@ -137,6 +145,7 @@ func (k Key) String() string {
if k.index != nil {
return strconv.Itoa(*k.index)
}

return k.key
}

Expand All @@ -161,6 +170,7 @@ func (ssp ParameterItem) String() string {
if key == "" {
return value
}

return fmt.Sprintf("%s=%s", key, value)
}

Expand Down
Loading

0 comments on commit cd7fa66

Please sign in to comment.