Skip to content

Commit

Permalink
fix: custom Host header. (#59)
Browse files Browse the repository at this point in the history
* fix: custom Host header.
traefik/traefik#6502

* feat: Use utils package to handle header modifications

* review

---------

Co-authored-by: Tom Moulard <tom@moulard.org>
  • Loading branch information
ISNing and tomMoulard authored May 14, 2024
1 parent 3424ef4 commit d3b793a
Show file tree
Hide file tree
Showing 17 changed files with 541 additions and 82 deletions.
3 changes: 2 additions & 1 deletion pkg/handler/deleter/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"net/http"

"github.com/tomMoulard/htransformation/pkg/types"
"github.com/tomMoulard/htransformation/pkg/utils/header"
)

func Validate(types.Rule) error {
Expand All @@ -17,5 +18,5 @@ func Handle(rw http.ResponseWriter, req *http.Request, rule types.Rule) {
return
}

req.Header.Del(rule.Header)
header.Delete(req, rule.Header)
}
29 changes: 21 additions & 8 deletions pkg/handler/deleter/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ func TestDeleteHandler(t *testing.T) {
t.Parallel()

tests := []struct {
name string
rule types.Rule
requestHeaders map[string]string
want map[string]string
name string
rule types.Rule
requestHeaders map[string]string
expectedHeaders map[string]string
expectedHost string
}{
{
name: "Remove not existing header",
Expand All @@ -29,9 +30,10 @@ func TestDeleteHandler(t *testing.T) {
requestHeaders: map[string]string{
"Foo": "Bar",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
},
expectedHost: "example.com",
},
{
name: "Remove one header",
Expand All @@ -42,9 +44,17 @@ func TestDeleteHandler(t *testing.T) {
"Foo": "Bar",
"X-Test": "Bar",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
},
expectedHost: "example.com",
},
{
name: "Remove host header",
rule: types.Rule{
Header: "Host",
},
expectedHost: "",
},
}

Expand All @@ -53,7 +63,7 @@ func TestDeleteHandler(t *testing.T) {
t.Parallel()

ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com/foo", nil)
require.NoError(t, err)

for hName, hVal := range test.requestHeaders {
Expand All @@ -62,9 +72,12 @@ func TestDeleteHandler(t *testing.T) {

deleter.Handle(nil, req, test.rule)

for hName, hVal := range test.want {
for hName, hVal := range test.expectedHeaders {
assert.Equal(t, hVal, req.Header.Get(hName))
}

assert.Equal(t, test.expectedHost, req.Host)
assert.Equal(t, "example.com", req.URL.Host)
})
}
}
Expand Down
27 changes: 22 additions & 5 deletions pkg/handler/join/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,16 @@ func Validate(rule types.Rule) error {
}

func Handle(rw http.ResponseWriter, req *http.Request, rule types.Rule) {
val, ok := req.Header[rule.Header]
if !ok {
return
var val []string
if strings.EqualFold(rule.Header, "Host") {
val = []string{req.Host}
} else {
var ok bool
val, ok = req.Header[rule.Header]

if !ok {
return
}
}

newHeaderVal := val[0]
Expand All @@ -32,7 +39,11 @@ func Handle(rw http.ResponseWriter, req *http.Request, rule types.Rule) {
return
}

req.Header.Set(rule.Header, newHeaderVal)
if strings.EqualFold(rule.Header, "Host") {
req.Host = newHeaderVal
} else {
req.Header.Set(rule.Header, newHeaderVal)
}
}

// getValue checks if prefix exists, the given prefix is present,
Expand All @@ -47,7 +58,13 @@ func getValue(ruleValue, valueIsHeaderPrefix string, req *http.Request) string {
// we return the actual value,
// which is the prefix itself.
// This is because doing a req.Header.Get("") would not fly well.
if header != "" {
if header == "" {
return actualValue
}

if strings.EqualFold(header, "Host") {
actualValue = req.Host
} else {
actualValue = req.Header.Get(header)
}
}
Expand Down
66 changes: 54 additions & 12 deletions pkg/handler/join/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (

func TestJoinHandler(t *testing.T) {
testCases := []struct {
name string
rule types.Rule
requestHeaders map[string]string
want map[string]string
name string
rule types.Rule
requestHeaders map[string]string
expectedHeaders map[string]string
expectedHost string
}{
{
name: "Join two headers simple value",
Expand All @@ -31,10 +32,11 @@ func TestJoinHandler(t *testing.T) {
"Foo": "Bar",
"X-Test": "Bar",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "Bar,Tested",
},
expectedHost: "example.com",
},
{
name: "Join two headers multiple value",
Expand All @@ -51,10 +53,11 @@ func TestJoinHandler(t *testing.T) {
"Foo": "Bar",
"X-Test": "Bar",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "Bar,Tested,Compiled,Working",
},
expectedHost: "example.com",
},
{
name: "Join two headers simple value",
Expand All @@ -72,11 +75,12 @@ func TestJoinHandler(t *testing.T) {
"X-Source": "Tested",
"X-Test": "Bar",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Source": "Tested",
"X-Test": "Bar,Tested",
},
expectedHost: "example.com",
},
{
name: "Join two headers multiple value",
Expand All @@ -97,12 +101,13 @@ func TestJoinHandler(t *testing.T) {
"X-Source-1": "Tested",
"X-Source-3": "Working",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "Bar,Tested,Compiled,Working",
"X-Source-1": "Tested",
"X-Source-3": "Working",
},
expectedHost: "example.com",
},
{
name: "Join two headers multiple value with itself",
Expand All @@ -122,10 +127,11 @@ func TestJoinHandler(t *testing.T) {
"X-Test": "test",
"X-Source-3": "third",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "test,second,test,third",
},
expectedHost: "example.com",
},
{
name: "Join value with same HeaderPrefix",
Expand All @@ -141,10 +147,43 @@ func TestJoinHandler(t *testing.T) {
"Foo": "Bar",
"X-Test": "Bar",
},
want: map[string]string{
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "Bar,Tested",
},
expectedHost: "example.com",
},
{
name: "Join Host header",
rule: types.Rule{
Sep: ",",
Header: "Host",
HeaderPrefix: "Tested",
Values: []string{
"Tested",
},
},
requestHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "Bar",
},
expectedHeaders: map[string]string{
"Foo": "Bar",
"X-Test": "Bar",
},
expectedHost: "example.com,Tested",
},
{
name: "Twice Host header",
rule: types.Rule{
Sep: ",",
Header: "Host",
Values: []string{
"^Host",
},
HeaderPrefix: "^",
},
expectedHost: "example.com,example.com",
},
}

Expand All @@ -153,7 +192,7 @@ func TestJoinHandler(t *testing.T) {
t.Parallel()

ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://example.com/foo", nil)
require.NoError(t, err)

for hName, hVal := range test.requestHeaders {
Expand All @@ -162,9 +201,12 @@ func TestJoinHandler(t *testing.T) {

join.Handle(nil, req, test.rule)

for hName, hVal := range test.want {
for hName, hVal := range test.expectedHeaders {
assert.Equal(t, hVal, req.Header.Get(hName))
}

assert.Equal(t, test.expectedHost, req.Host)
assert.Equal(t, "example.com", req.URL.Host)
})
}
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/handler/rename/rename.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"regexp"

"github.com/tomMoulard/htransformation/pkg/types"
"github.com/tomMoulard/htransformation/pkg/utils/header"
)

func Validate(rule types.Rule) error {
Expand All @@ -21,6 +22,9 @@ func Validate(rule types.Rule) error {
}

func Handle(rw http.ResponseWriter, req *http.Request, rule types.Rule) {
originalHost := req.Header.Get("Host") // Eventually X-Forwarded-Host
req.Header.Set("Host", req.Host)

for headerName, headerValues := range req.Header {
if matched := rule.Regexp.Match([]byte(headerName)); !matched {
continue
Expand All @@ -29,15 +33,17 @@ func Handle(rw http.ResponseWriter, req *http.Request, rule types.Rule) {
if rule.SetOnResponse {
rw.Header().Del(headerName)
} else {
req.Header.Del(headerName)
header.Delete(req, headerName)
}

for _, val := range headerValues {
if rule.SetOnResponse {
rw.Header().Set(rule.Value, val)
} else {
req.Header.Set(rule.Value, val)
header.Set(req, rule.Value, val)
}
}
}

req.Header.Set("Host", originalHost)
}
Loading

0 comments on commit d3b793a

Please sign in to comment.