diff --git a/encoding/json.go b/encoding/json.go index e35c507..85156ba 100644 --- a/encoding/json.go +++ b/encoding/json.go @@ -19,7 +19,7 @@ import ( ) func jsonReGetField(key, s, catch string) (string, error) { - r := fmt.Sprintf(`%q:%s`, key, catch) + r := fmt.Sprintf(`%q:%s`, regexp.QuoteMeta(key), catch) re := regexp.MustCompile(r) matches := re.FindStringSubmatch(s) @@ -30,15 +30,9 @@ func jsonReGetField(key, s, catch string) (string, error) { return matches[1], nil } -// JSONReGetGroup attempts to find the group JSON encoding in s. The optional key argument overrides the default key the -// regex will use to look for the group. -func JSONReGetGroup(s string, key ...string) (ecc.Group, error) { - reKey := "group" - if len(key) != 0 && key[0] != "" { - reKey = key[0] - } - - f, err := jsonReGetField(reKey, s, `(\w+)`) +// JSONReGetGroup attempts to find the group JSON encoding in s. +func JSONReGetGroup(s string) (ecc.Group, error) { + f, err := jsonReGetField("group", s, `(\w+)`) if err != nil { return 0, err } diff --git a/groups.go b/groups.go index 90098fc..7398ffa 100644 --- a/groups.go +++ b/groups.go @@ -60,7 +60,6 @@ const ( var ( once [maxID - 1]sync.Once groups [maxID - 1]internal.Group - errInvalidID = errors.New("invalid group identifier") errZeroLenDST = errors.New("zero-length DST") ) @@ -71,7 +70,7 @@ func (g Group) Available() bool { func (g Group) get() internal.Group { if !g.Available() { - panic(errInvalidID) + panic(internal.ErrInvalidGroup) } once[g-1].Do(g.init) diff --git a/tests/encoding_test.go b/tests/encoding_test.go index 0204636..527550a 100644 --- a/tests/encoding_test.go +++ b/tests/encoding_test.go @@ -288,28 +288,5 @@ func TestJSONReGetGroup(t *testing.T) { if g != group.group { t.Fatal(errExpectedEquality) } - - // with another key - test2 := struct { - Group ecc.Group `json:"ciphersuite"` - Int int `json:"int"` - }{ - Group: group.group, - Int: 1, - } - - enc, err = json.Marshal(test2) - if err != nil { - t.Fatal(err) - } - - g, err = eccEncoding.JSONReGetGroup(string(enc), "ciphersuite") - if err != nil { - t.Fatal(err) - } - - if g != group.group { - t.Fatal(errExpectedEquality) - } }) } diff --git a/tests/fuzz_test.go b/tests/fuzz_test.go new file mode 100644 index 0000000..77aec66 --- /dev/null +++ b/tests/fuzz_test.go @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// +// Copyright (C) 2020-2024 Daniel Bourdrez. All Rights Reserved. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree or at +// https://spdx.org/licenses/MIT.html + +package ecc_test + +import ( + "testing" + + "github.com/bytemare/ecc" + "github.com/bytemare/ecc/encoding" + "github.com/bytemare/ecc/internal" +) + +func FuzzGroup(f *testing.F) { + f.Fuzz(func(t *testing.T, group byte, h2Input, h2DST []byte, dstApp string, dstVersion uint8) { + if panicked, err := hasPanic(func() { + g := ecc.Group(group) + + if len(g.MakeDST(dstApp, dstVersion)) == 0 { + t.Fatal("unexpected 0 length dst") + } + + if len(h2DST) != 0 { + one := g.NewScalar().SetUInt64(1) + if s := g.HashToScalar(h2Input, h2DST); s.IsZero() || s.Equal(one) { + t.Fatal("HashToScalar yielded 0 or 1") + } + + if e := g.HashToGroup(h2Input, h2DST); e.IsIdentity() || e.Equal(g.Base()) { + t.Fatal("HashToGroup yielded identity or generator") + } + + if e := g.EncodeToGroup(h2Input, h2DST); e.IsIdentity() || e.Equal(g.Base()) { + t.Fatal("HashToGroup yielded identity or generator") + } + } + }); panicked && err.Error() != internal.ErrInvalidGroup.Error() { + t.Fatal(err) + } + }) +} + +func FuzzScalar(f *testing.F) { + f.Fuzz(func(t *testing.T, group byte, input []byte, i uint64) { + if panicked, err := hasPanic(func() { + g := ecc.Group(group) + s := g.NewScalar() + + s.SetUInt64(i) + _ = s.Decode(input) + _ = s.DecodeHex(string(input)) + _ = s.UnmarshalJSON(input) + _ = s.UnmarshalBinary(input) + }); panicked && err.Error() != internal.ErrInvalidGroup.Error() { + t.Fatal(err) + } + }) +} + +func FuzzElement(f *testing.F) { + f.Fuzz(func(t *testing.T, group byte, input []byte) { + if panicked, err := hasPanic(func() { + g := ecc.Group(group) + s := g.NewScalar() + + _ = s.Decode(input) + _ = s.DecodeHex(string(input)) + _ = s.UnmarshalJSON(input) + _ = s.UnmarshalBinary(input) + }); panicked && err.Error() != internal.ErrInvalidGroup.Error() { + t.Fatal(err) + } + }) +} + +func FuzzJSONReGetGroup(f *testing.F) { + f.Fuzz(func(t *testing.T, input string) { + _, _ = encoding.JSONReGetGroup(input) + }) +} diff --git a/tests/groups_test.go b/tests/groups_test.go index 9e2b95d..1d45722 100644 --- a/tests/groups_test.go +++ b/tests/groups_test.go @@ -10,11 +10,11 @@ package ecc_test import ( "encoding/hex" - "errors" "fmt" "testing" "github.com/bytemare/ecc" + "github.com/bytemare/ecc/internal" ) const consideredAvailableFmt = "%v is considered available when it must not" @@ -28,8 +28,6 @@ func TestAvailability(t *testing.T) { } func TestNonAvailability(t *testing.T) { - errInvalidID := errors.New("invalid group identifier") - oob := ecc.Group(0) if oob.Available() { t.Errorf(consideredAvailableFmt, oob) @@ -40,7 +38,7 @@ func TestNonAvailability(t *testing.T) { t.Errorf(consideredAvailableFmt, d) } - if err := testPanic("decaf availability", errInvalidID, + if err := testPanic("decaf availability", internal.ErrInvalidGroup, func() { _ = d.String() }); err != nil { t.Fatal(err) } @@ -50,13 +48,13 @@ func TestNonAvailability(t *testing.T) { t.Errorf(consideredAvailableFmt, oob) } - if err := testPanic("oob availability", errInvalidID, + if err := testPanic("oob availability", internal.ErrInvalidGroup, func() { _ = oob.String() }); err != nil { t.Fatal(err) } oob++ - if err := testPanic("oob availability", errInvalidID, + if err := testPanic("oob availability", internal.ErrInvalidGroup, func() { _ = oob.String() }); err != nil { t.Fatal(err) } diff --git a/tests/utils_test.go b/tests/utils_test.go index f6b626d..11e2f80 100644 --- a/tests/utils_test.go +++ b/tests/utils_test.go @@ -24,6 +24,8 @@ var ( errWrapGroup = "%s: %w" ) +// hasPanic runs f and recovers from a panic if any occurred, and returns whether it did and the panic message as an +// error. func hasPanic(f func()) (has bool, err error) { defer func() { var report any