Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trie: [wip] reduce allocations in derivesha #30747

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
14 changes: 7 additions & 7 deletions core/types/hashing.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ func prefixedRlpHash(prefix byte, x interface{}) (h common.Hash) {
type TrieHasher interface {
Reset()
Update([]byte, []byte) error
// UpdateSafe is identical to Update, except that this method will copy the
// value slice. The caller is free to modify the value bytes after this method returns.
UpdateSafe([]byte, []byte) error
Hash() common.Hash
}

Expand All @@ -95,10 +98,7 @@ type DerivableList interface {
func encodeForDerive(list DerivableList, i int, buf *bytes.Buffer) []byte {
buf.Reset()
list.EncodeIndex(i, buf)
// It's really unfortunate that we need to perform this copy.
// StackTrie holds onto the values until Hash is called, so the values
// written to it must not alias.
return common.CopyBytes(buf.Bytes())
return buf.Bytes()
}

// DeriveSha creates the tree hashes of transactions, receipts, and withdrawals in a block header.
Expand All @@ -118,17 +118,17 @@ func DeriveSha(list DerivableList, hasher TrieHasher) common.Hash {
for i := 1; i < list.Len() && i <= 0x7f; i++ {
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
value := encodeForDerive(list, i, valueBuf)
hasher.Update(indexBuf, value)
hasher.UpdateSafe(indexBuf, value)
}
if list.Len() > 0 {
indexBuf = rlp.AppendUint64(indexBuf[:0], 0)
value := encodeForDerive(list, 0, valueBuf)
hasher.Update(indexBuf, value)
hasher.UpdateSafe(indexBuf, value)
}
for i := 0x80; i < list.Len(); i++ {
indexBuf = rlp.AppendUint64(indexBuf[:0], uint64(i))
value := encodeForDerive(list, i, valueBuf)
hasher.Update(indexBuf, value)
hasher.UpdateSafe(indexBuf, value)
}
return hasher.Hash()
}
25 changes: 18 additions & 7 deletions core/types/hashing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,31 @@ func BenchmarkDeriveSha200(b *testing.B) {
if err != nil {
b.Fatal(err)
}
var exp common.Hash
var got common.Hash
want := types.DeriveSha(txs, trie.NewEmpty(triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)))
var have common.Hash
b.Run("std_trie", func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
exp = types.DeriveSha(txs, trie.NewEmpty(triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)))
have = types.DeriveSha(txs, trie.NewEmpty(triedb.NewDatabase(rawdb.NewMemoryDatabase(), nil)))
}
if have != want {
b.Errorf("have %x want %x", have, want)
}
})

st := trie.NewStackTrie(nil)
b.Run("stack_trie", func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
got = types.DeriveSha(txs, trie.NewStackTrie(nil))
st.Reset()
have = types.DeriveSha(txs, st)
}
if have != want {
b.Errorf("have %x want %x", have, want)
}
})
if got != exp {
b.Errorf("got %x exp %x", got, exp)
}
}

func TestFuzzDeriveSha(t *testing.T) {
Expand Down Expand Up @@ -226,6 +231,12 @@ func (d *hashToHumanReadable) Update(i []byte, i2 []byte) error {
return nil
}

// UpdateSafe is identical to Update, except that this method will copy the
// value slice. The caller is free to modify the value bytes after this method returns.
func (d *hashToHumanReadable) UpdateSafe(key, value []byte) error {
return d.Update(key, common.CopyBytes(value))
}

func (d *hashToHumanReadable) Hash() common.Hash {
return common.Hash{}
}
18 changes: 15 additions & 3 deletions core/types/tx_blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package types

import (
"crypto/ecdsa"
"sync"
"testing"

"github.com/ethereum/go-ethereum/common"
Expand Down Expand Up @@ -58,19 +59,30 @@ func TestBlobTxSize(t *testing.T) {
}
}

// emptyInit ensures that we init the kzg empties only once
var (
emptyBlob = new(kzg4844.Blob)
emptyBlobCommit, _ = kzg4844.BlobToCommitment(emptyBlob)
emptyBlobProof, _ = kzg4844.ComputeBlobProof(emptyBlob, emptyBlobCommit)
emptyInit sync.Once
emptyBlob *kzg4844.Blob
emptyBlobCommit kzg4844.Commitment
emptyBlobProof kzg4844.Proof
)

func initEmpties() {
emptyInit.Do(func() {
emptyBlob = new(kzg4844.Blob)
emptyBlobCommit, _ = kzg4844.BlobToCommitment(emptyBlob)
emptyBlobProof, _ = kzg4844.ComputeBlobProof(emptyBlob, emptyBlobCommit)
})
}

func createEmptyBlobTx(key *ecdsa.PrivateKey, withSidecar bool) *Transaction {
blobtx := createEmptyBlobTxInner(withSidecar)
signer := NewCancunSigner(blobtx.ChainID.ToBig())
return MustSignNewTx(key, signer, blobtx)
}

func createEmptyBlobTxInner(withSidecar bool) *BlobTx {
initEmpties()
sidecar := &BlobTxSidecar{
Blobs: []kzg4844.Blob{*emptyBlob},
Commitments: []kzg4844.Commitment{emptyBlobCommit},
Expand Down
6 changes: 6 additions & 0 deletions internal/blocktest/test_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ func (h *testHasher) Update(key, val []byte) error {
return nil
}

// UpdateSafe is identical to Update, except that this method will copy the
// value slice. The caller is free to modify the value bytes after this method returns.
func (h *testHasher) UpdateSafe(key, value []byte) error {
return h.Update(key, common.CopyBytes(value))
}

// Hash returns the hash value.
func (h *testHasher) Hash() common.Hash {
return common.BytesToHash(h.hasher.Sum(nil))
Expand Down
92 changes: 92 additions & 0 deletions trie/bytepool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright 2024 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package trie

// bytesPool is a pool for byteslices. It is safe for concurrent use.
type bytesPool struct {
c chan []byte
w int
}

// newBytesPool creates a new bytesPool. The sliceCap sets the capacity of
// newly allocated slices, and the nitems determines how many items the pool
// will hold, at maximum.
func newBytesPool(sliceCap, nitems int) *bytesPool {
return &bytesPool{
c: make(chan []byte, nitems),
w: sliceCap,
}
}

// Get returns a slice. Safe for concurrent use.
func (bp *bytesPool) Get() []byte {
select {
case b := <-bp.c:
return b
default:
return make([]byte, 0, bp.w)
}
}

// Put returns a slice to the pool. Safe for concurrent use. This method
// will ignore slices that are too small or too large (>3x the cap)
func (bp *bytesPool) Put(b []byte) {
if c := cap(b); c < bp.w || c > 3*bp.w {
return
}
select {
case bp.c <- b:
default:
}
}

// unsafeBytesPool is a pool for byteslices. It is not safe for concurrent use.
type unsafeBytesPool struct {
items [][]byte
w int
}

// newUnsafeBytesPool creates a new bytesPool. The sliceCap sets the capacity of
// newly allocated slices, and the nitems determines how many items the pool
// will hold, at maximum.
func newUnsafeBytesPool(sliceCap, nitems int) *unsafeBytesPool {
return &unsafeBytesPool{
items: make([][]byte, 0, nitems),
w: sliceCap,
}
}

// Get returns a slice.
func (bp *unsafeBytesPool) Get() []byte {
if len(bp.items) > 0 {
last := bp.items[len(bp.items)-1]
bp.items = bp.items[:len(bp.items)-1]
return last
}
return make([]byte, 0, bp.w)
}

// Put returns a slice to the pool. This method
// will ignore slices that are too small or too large (>3x the cap)
func (bp *unsafeBytesPool) Put(b []byte) {
if c := cap(b); c < bp.w || c > 3*bp.w {
return
}
if len(bp.items) < cap(bp.items) {
bp.items = append(bp.items, b)
}
}
11 changes: 11 additions & 0 deletions trie/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ func keybytesToHex(str []byte) []byte {
return nibbles
}

// writeHexKey writes the hexkey into the given slice.
// OBS! This method omits the termination flag.
// OBS! The dst slice must be at least 2x as large as the key
func writeHexKey(dst []byte, key []byte) {
_ = dst[2*len(key)-1]
for i, b := range key {
dst[i*2] = b / 16
dst[i*2+1] = b % 16
}
}

// hexToKeybytes turns hex nibbles into key bytes.
// This can only be used for keys of even length.
func hexToKeybytes(hex []byte) []byte {
Expand Down
8 changes: 8 additions & 0 deletions trie/hasher.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ func (h *hasher) hashData(data []byte) hashNode {
return n
}

// hashDataTo hashes the provided data to the given destination buffer. The caller
// must ensure that the dst buffer is of appropriate size.
func (h *hasher) hashDataTo(dst, data []byte) {
h.sha.Reset()
h.sha.Write(data)
h.sha.Read(dst)
}

// proofHash is used to construct trie proofs, and returns the 'collapsed'
// node (for later RLP encoding) as well as the hashed node -- unless the
// node is smaller than 32 bytes, in which case it will be returned as is.
Expand Down
52 changes: 44 additions & 8 deletions trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ type (
}
hashNode []byte
valueNode []byte

//fullnodeEncoder is a type used exclusively for encoding. Briefly instantiating
// a fullnodeEncoder and initializing with existing slices is less memory
// intense than using the fullNode type.
fullnodeEncoder struct {
Children [17][]byte
flags nodeFlag
}

//shortNodeEncoder is a type used exclusively for encoding. Briefly instantiating
// a shortNodeEncoder and initializing with existing slices is less memory
// intense than using the shortNode type.
shortNodeEncoder struct {
Key []byte
Val []byte
flags nodeFlag
}
)

// nilValueNode is used when collapsing internal trie nodes for hashing, since
Expand All @@ -67,16 +84,20 @@ type nodeFlag struct {
dirty bool // whether the node has changes that must be written to the database
}

func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n hashNode) cache() (hashNode, bool) { return nil, true }
func (n valueNode) cache() (hashNode, bool) { return nil, true }
func (n *fullNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n *fullnodeEncoder) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n *shortNodeEncoder) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
func (n hashNode) cache() (hashNode, bool) { return nil, true }
func (n valueNode) cache() (hashNode, bool) { return nil, true }

// Pretty printing.
func (n *fullNode) String() string { return n.fstring("") }
func (n *shortNode) String() string { return n.fstring("") }
func (n hashNode) String() string { return n.fstring("") }
func (n valueNode) String() string { return n.fstring("") }
func (n *fullNode) String() string { return n.fstring("") }
func (n *fullnodeEncoder) String() string { return n.fstring("") }
func (n *shortNode) String() string { return n.fstring("") }
func (n *shortNodeEncoder) String() string { return n.fstring("") }
func (n hashNode) String() string { return n.fstring("") }
func (n valueNode) String() string { return n.fstring("") }

func (n *fullNode) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind)
Expand All @@ -89,9 +110,24 @@ func (n *fullNode) fstring(ind string) string {
}
return resp + fmt.Sprintf("\n%s] ", ind)
}

func (n *fullnodeEncoder) fstring(ind string) string {
resp := fmt.Sprintf("[\n%s ", ind)
for i, node := range &n.Children {
if node == nil {
resp += fmt.Sprintf("%s: <nil> ", indices[i])
} else {
resp += fmt.Sprintf("%s: %x", indices[i], node)
}
}
return resp + fmt.Sprintf("\n%s] ", ind)
}
func (n *shortNode) fstring(ind string) string {
return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+" "))
}
func (n *shortNodeEncoder) fstring(ind string) string {
return fmt.Sprintf("{%x: %x} ", n.Key, n.Val)
}
func (n hashNode) fstring(ind string) string {
return fmt.Sprintf("<%x> ", []byte(n))
}
Expand Down
28 changes: 28 additions & 0 deletions trie/node_enc.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,20 @@ func (n *fullNode) encode(w rlp.EncoderBuffer) {
w.ListEnd(offset)
}

func (n *fullnodeEncoder) encode(w rlp.EncoderBuffer) {
offset := w.List()
for _, c := range n.Children {
if c == nil {
w.Write(rlp.EmptyString)
} else if len(c) < 32 {
w.Write(c) // rawNode
} else {
w.WriteBytes(c) // hashNode
}
}
w.ListEnd(offset)
}

func (n *shortNode) encode(w rlp.EncoderBuffer) {
offset := w.List()
w.WriteBytes(n.Key)
Expand All @@ -51,6 +65,20 @@ func (n *shortNode) encode(w rlp.EncoderBuffer) {
w.ListEnd(offset)
}

func (n *shortNodeEncoder) encode(w rlp.EncoderBuffer) {
offset := w.List()
w.WriteBytes(n.Key)

if n.Val == nil {
w.Write(rlp.EmptyString)
} else if len(n.Val) < 32 {
w.Write(n.Val) // rawNode
} else {
w.WriteBytes(n.Val) // hashNode
}
w.ListEnd(offset)
}

func (n hashNode) encode(w rlp.EncoderBuffer) {
w.WriteBytes(n)
}
Expand Down
Loading
Loading