Skip to content

Commit

Permalink
Merge pull request #248 from goccy/feature/context
Browse files Browse the repository at this point in the history
Support context for MarshalJSON and UnmarshalJSON
  • Loading branch information
goccy authored Jun 12, 2021
2 parents 3c3226e + cd7fb73 commit 5c22860
Show file tree
Hide file tree
Showing 17 changed files with 355 additions and 51 deletions.
43 changes: 41 additions & 2 deletions decode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package json

import (
"context"
"fmt"
"io"
"reflect"
Expand Down Expand Up @@ -39,7 +40,7 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
}
ctx := decoder.TakeRuntimeContext()
ctx.Buf = src
ctx.Option.Flag = 0
ctx.Option.Flags = 0
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
Expand All @@ -52,6 +53,36 @@ func unmarshal(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
return validateEndBuf(src, cursor)
}

func unmarshalContext(ctx context.Context, data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data)

header := (*emptyInterface)(unsafe.Pointer(&v))

if err := validateType(header.typ, uintptr(header.ptr)); err != nil {
return err
}
dec, err := decoder.CompileToGetDecoder(header.typ)
if err != nil {
return err
}
rctx := decoder.TakeRuntimeContext()
rctx.Buf = src
rctx.Option.Flags = 0
rctx.Option.Flags |= decoder.ContextOption
rctx.Option.Context = ctx
for _, optFunc := range optFuncs {
optFunc(rctx.Option)
}
cursor, err := dec.Decode(rctx, 0, 0, header.ptr)
if err != nil {
decoder.ReleaseRuntimeContext(rctx)
return err
}
decoder.ReleaseRuntimeContext(rctx)
return validateEndBuf(src, cursor)
}

func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc) error {
src := make([]byte, len(data)+1) // append nul byte to the end
copy(src, data)
Expand All @@ -68,7 +99,7 @@ func unmarshalNoEscape(data []byte, v interface{}, optFuncs ...DecodeOptionFunc)

ctx := decoder.TakeRuntimeContext()
ctx.Buf = src
ctx.Option.Flag = 0
ctx.Option.Flags = 0
for _, optFunc := range optFuncs {
optFunc(ctx.Option)
}
Expand Down Expand Up @@ -137,6 +168,14 @@ func (d *Decoder) Decode(v interface{}) error {
return d.DecodeWithOption(v)
}

// DecodeContext reads the next JSON-encoded value from its
// input and stores it in the value pointed to by v with context.Context.
func (d *Decoder) DecodeContext(ctx context.Context, v interface{}) error {
d.s.Option.Flags |= decoder.ContextOption
d.s.Option.Context = ctx
return d.DecodeWithOption(v)
}

func (d *Decoder) DecodeWithOption(v interface{}, optFuncs ...DecodeOptionFunc) error {
header := (*emptyInterface)(unsafe.Pointer(&v))
typ := header.typ
Expand Down
46 changes: 46 additions & 0 deletions decode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package json_test

import (
"bytes"
"context"
"encoding"
stdjson "encoding/json"
"errors"
Expand Down Expand Up @@ -3620,3 +3621,48 @@ func TestDecodeEscapedCharField(t *testing.T) {
}
})
}

type unmarshalContextKey struct{}

type unmarshalContextStructType struct {
v int
}

func (t *unmarshalContextStructType) UnmarshalJSON(ctx context.Context, b []byte) error {
v := ctx.Value(unmarshalContextKey{})
s, ok := v.(string)
if !ok {
return fmt.Errorf("failed to propagate parent context.Context")
}
if s != "hello" {
return fmt.Errorf("failed to propagate parent context.Context")
}
t.v = 100
return nil
}

func TestDecodeContextOption(t *testing.T) {
src := []byte("10")
buf := bytes.NewBuffer(src)

t.Run("UnmarshalContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
var v unmarshalContextStructType
if err := json.UnmarshalContext(ctx, src, &v); err != nil {
t.Fatal(err)
}
if v.v != 100 {
t.Fatal("failed to decode with context")
}
})
t.Run("DecodeContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), unmarshalContextKey{}, "hello")
var v unmarshalContextStructType
if err := json.NewDecoder(buf).DecodeContext(ctx, &v); err != nil {
t.Fatal(err)
}
if v.v != 100 {
t.Fatal("failed to decode with context")
}
})
}
43 changes: 42 additions & 1 deletion encode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package json

import (
"context"
"io"
"unsafe"

Expand Down Expand Up @@ -35,15 +36,28 @@ func (e *Encoder) Encode(v interface{}) error {
// EncodeWithOption call Encode with EncodeOption.
func (e *Encoder) EncodeWithOption(v interface{}, optFuncs ...EncodeOptionFunc) error {
ctx := encoder.TakeRuntimeContext()
ctx.Option.Flag = 0

err := e.encodeWithOption(ctx, v, optFuncs...)

encoder.ReleaseRuntimeContext(ctx)
return err
}

// EncodeContext call Encode with context.Context and EncodeOption.
func (e *Encoder) EncodeContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) error {
rctx := encoder.TakeRuntimeContext()
rctx.Option.Flag = 0
rctx.Option.Flag |= encoder.ContextOption
rctx.Option.Context = ctx

err := e.encodeWithOption(rctx, v, optFuncs...)

encoder.ReleaseRuntimeContext(rctx)
return err
}

func (e *Encoder) encodeWithOption(ctx *encoder.RuntimeContext, v interface{}, optFuncs ...EncodeOptionFunc) error {
ctx.Option.Flag = 0
if e.enabledHTMLEscape {
ctx.Option.Flag |= encoder.HTMLEscapeOption
}
Expand Down Expand Up @@ -94,6 +108,33 @@ func (e *Encoder) SetIndent(prefix, indent string) {
e.enabledIndent = true
}

func marshalContext(ctx context.Context, v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
rctx := encoder.TakeRuntimeContext()
rctx.Option.Flag = 0
rctx.Option.Flag = encoder.HTMLEscapeOption | encoder.ContextOption
rctx.Option.Context = ctx
for _, optFunc := range optFuncs {
optFunc(rctx.Option)
}

buf, err := encode(rctx, v)
if err != nil {
encoder.ReleaseRuntimeContext(rctx)
return nil, err
}

// this line exists to escape call of `runtime.makeslicecopy` .
// if use `make([]byte, len(buf)-1)` and `copy(copied, buf)`,
// dst buffer size and src buffer size are differrent.
// in this case, compiler uses `runtime.makeslicecopy`, but it is slow.
buf = buf[:len(buf)-1]
copied := make([]byte, len(buf))
copy(copied, buf)

encoder.ReleaseRuntimeContext(rctx)
return copied, nil
}

func marshal(v interface{}, optFuncs ...EncodeOptionFunc) ([]byte, error) {
ctx := encoder.TakeRuntimeContext()

Expand Down
40 changes: 40 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package json_test

import (
"bytes"
"context"
"encoding"
stdjson "encoding/json"
"errors"
Expand Down Expand Up @@ -1918,3 +1919,42 @@ func TestEncodeMapKeyTypeInterface(t *testing.T) {
t.Fatal("expected error")
}
}

type marshalContextKey struct{}

type marshalContextStructType struct{}

func (t *marshalContextStructType) MarshalJSON(ctx context.Context) ([]byte, error) {
v := ctx.Value(marshalContextKey{})
s, ok := v.(string)
if !ok {
return nil, fmt.Errorf("failed to propagate parent context.Context")
}
if s != "hello" {
return nil, fmt.Errorf("failed to propagate parent context.Context")
}
return []byte(`"success"`), nil
}

func TestEncodeContextOption(t *testing.T) {
t.Run("MarshalContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
b, err := json.MarshalContext(ctx, &marshalContextStructType{})
if err != nil {
t.Fatal(err)
}
if string(b) != `"success"` {
t.Fatal("failed to encode with MarshalerContext")
}
})
t.Run("EncodeContext", func(t *testing.T) {
ctx := context.WithValue(context.Background(), marshalContextKey{}, "hello")
buf := bytes.NewBuffer([]byte{})
if err := json.NewEncoder(buf).EncodeContext(ctx, &marshalContextStructType{}); err != nil {
t.Fatal(err)
}
if buf.String() != "\"success\"\n" {
t.Fatal("failed to encode with EncodeContext")
}
})
}
10 changes: 7 additions & 3 deletions internal/decoder/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func compileToGetDecoderSlowPath(typeptr uintptr, typ *runtime.Type) (Decoder, e

func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
switch {
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), "", ""), nil
case runtime.PtrTo(typ).Implements(unmarshalTextType):
return newUnmarshalTextDecoder(runtime.PtrTo(typ), "", ""), nil
Expand All @@ -70,7 +70,7 @@ func compileHead(typ *runtime.Type, structTypeToDecoder map[uintptr]Decoder) (De

func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecoder map[uintptr]Decoder) (Decoder, error) {
switch {
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
return newUnmarshalJSONDecoder(runtime.PtrTo(typ), structName, fieldName), nil
case runtime.PtrTo(typ).Implements(unmarshalTextType):
return newUnmarshalTextDecoder(runtime.PtrTo(typ), structName, fieldName), nil
Expand Down Expand Up @@ -133,7 +133,7 @@ func compile(typ *runtime.Type, structName, fieldName string, structTypeToDecode

func isStringTagSupportedType(typ *runtime.Type) bool {
switch {
case runtime.PtrTo(typ).Implements(unmarshalJSONType):
case implementsUnmarshalJSONType(runtime.PtrTo(typ)):
return false
case runtime.PtrTo(typ).Implements(unmarshalTextType):
return false
Expand Down Expand Up @@ -494,3 +494,7 @@ func compileStruct(typ *runtime.Type, structName, fieldName string, structTypeTo
structDec.tryOptimize()
return structDec, nil
}

func implementsUnmarshalJSONType(typ *runtime.Type) bool {
return typ.Implements(unmarshalJSONType) || typ.Implements(unmarshalJSONContextType)
}
38 changes: 38 additions & 0 deletions internal/decoder/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ func decodeStreamUnmarshaler(s *Stream, depth int64, unmarshaler json.Unmarshale
return nil
}

func decodeStreamUnmarshalerContext(s *Stream, depth int64, unmarshaler unmarshalerContext) error {
start := s.cursor
if err := s.skipValue(depth); err != nil {
return err
}
src := s.buf[start:s.cursor]
dst := make([]byte, len(src))
copy(dst, src)

if err := unmarshaler.UnmarshalJSON(s.Option.Context, dst); err != nil {
return err
}
return nil
}

func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarshaler) (int64, error) {
cursor = skipWhiteSpace(buf, cursor)
start := cursor
Expand All @@ -134,6 +149,23 @@ func decodeUnmarshaler(buf []byte, cursor, depth int64, unmarshaler json.Unmarsh
return end, nil
}

func decodeUnmarshalerContext(ctx *RuntimeContext, buf []byte, cursor, depth int64, unmarshaler unmarshalerContext) (int64, error) {
cursor = skipWhiteSpace(buf, cursor)
start := cursor
end, err := skipValue(buf, cursor, depth)
if err != nil {
return 0, err
}
src := buf[start:end]
dst := make([]byte, len(src))
copy(dst, src)

if err := unmarshaler.UnmarshalJSON(ctx.Option.Context, dst); err != nil {
return 0, err
}
return end, nil
}

func decodeStreamTextUnmarshaler(s *Stream, depth int64, unmarshaler encoding.TextUnmarshaler, p unsafe.Pointer) error {
start := s.cursor
if err := s.skipValue(depth); err != nil {
Expand Down Expand Up @@ -260,6 +292,9 @@ func (d *interfaceDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer
}))
rv := reflect.ValueOf(runtimeInterfaceValue)
if rv.NumMethod() > 0 && rv.CanInterface() {
if u, ok := rv.Interface().(unmarshalerContext); ok {
return decodeStreamUnmarshalerContext(s, depth, u)
}
if u, ok := rv.Interface().(json.Unmarshaler); ok {
return decodeStreamUnmarshaler(s, depth, u)
}
Expand Down Expand Up @@ -317,6 +352,9 @@ func (d *interfaceDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p un
}))
rv := reflect.ValueOf(runtimeInterfaceValue)
if rv.NumMethod() > 0 && rv.CanInterface() {
if u, ok := rv.Interface().(unmarshalerContext); ok {
return decodeUnmarshalerContext(ctx, buf, cursor, depth, u)
}
if u, ok := rv.Interface().(json.Unmarshaler); ok {
return decodeUnmarshaler(buf, cursor, depth, u)
}
Expand Down
10 changes: 7 additions & 3 deletions internal/decoder/option.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package decoder

type OptionFlag int
import "context"

type OptionFlags uint8

const (
FirstWinOption OptionFlag = 1 << iota
FirstWinOption OptionFlags = 1 << iota
ContextOption
)

type Option struct {
Flag OptionFlag
Flags OptionFlags
Context context.Context
}
4 changes: 2 additions & 2 deletions internal/decoder/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ func (d *structDecoder) DecodeStream(s *Stream, depth int64, p unsafe.Pointer) e
seenFields map[int]struct{}
seenFieldNum int
)
firstWin := (s.Option.Flag & FirstWinOption) != 0
firstWin := (s.Option.Flags & FirstWinOption) != 0
if firstWin {
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
}
Expand Down Expand Up @@ -752,7 +752,7 @@ func (d *structDecoder) Decode(ctx *RuntimeContext, cursor, depth int64, p unsaf
seenFields map[int]struct{}
seenFieldNum int
)
firstWin := (ctx.Option.Flag & FirstWinOption) != 0
firstWin := (ctx.Option.Flags & FirstWinOption) != 0
if firstWin {
seenFields = make(map[int]struct{}, d.fieldUniqueNameNum)
}
Expand Down
Loading

0 comments on commit 5c22860

Please sign in to comment.