Skip to content

Commit

Permalink
Merge pull request #229 from neutron-org/feat/msg-dispatcher-patch
Browse files Browse the repository at this point in the history
Feat: optimized message dispatcher
  • Loading branch information
pr0n00gler authored Aug 21, 2024
2 parents eafa99c + 22f2771 commit 465a2a8
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 5 deletions.
14 changes: 14 additions & 0 deletions x/wasm/keeper/handler_plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,20 @@ func (h SDKMessageHandler) handleSdkMessage(ctx sdk.Context, contractAddr sdk.Ad
return nil, errorsmod.Wrapf(sdkerrors.ErrUnknownRequest, "can't route message %+v", msg)
}

type callDepthMessageHandler struct {
Messenger
MaxCallDepth uint32
}

func (h callDepthMessageHandler) DispatchMsg(ctx sdk.Context, contractAddr sdk.AccAddress, contractIBCPortID string, msg wasmvmtypes.CosmosMsg) (events []sdk.Event, data [][]byte, msgResponses [][]*codectypes.Any, err error) {
ctx, err = checkAndIncreaseCallDepth(ctx, h.MaxCallDepth)
if err != nil {
return nil, nil, nil, err
}

return h.Messenger.DispatchMsg(ctx, contractAddr, contractIBCPortID, msg)
}

// MessageHandlerChain defines a chain of handlers that are called one by one until it can be handled.
type MessageHandlerChain struct {
handlers []Messenger
Expand Down
20 changes: 20 additions & 0 deletions x/wasm/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ type Keeper struct {
queryGasLimit uint64
gasRegister types.GasRegister
maxQueryStackSize uint32
maxCallDepth uint32
acceptedAccountTypes map[reflect.Type]struct{}
accountPruner AccountPruner
params collections.Item[types.Params]
Expand Down Expand Up @@ -785,6 +786,7 @@ func (k Keeper) mustGetLastContractHistoryEntry(ctx context.Context, contractAdd
// QuerySmart queries the smart contract itself.
func (k Keeper) QuerySmart(ctx context.Context, contractAddr sdk.AccAddress, req []byte) ([]byte, error) {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "query-smart")

// checks and increase query stack size
sdkCtx, err := checkAndIncreaseQueryStackSize(sdk.UnwrapSDKContext(ctx), k.maxQueryStackSize)
if err != nil {
Expand Down Expand Up @@ -832,6 +834,24 @@ func checkAndIncreaseQueryStackSize(ctx context.Context, maxQueryStackSize uint3
return types.WithQueryStackSize(sdk.UnwrapSDKContext(ctx), queryStackSize), nil
}

func checkAndIncreaseCallDepth(ctx context.Context, maxCallDepth uint32) (sdk.Context, error) {
var callDepth uint32 = 0
if size, ok := types.CallDepth(ctx); ok {
callDepth = size
}

// increase
callDepth++

// did we go too far?
if callDepth > maxCallDepth {
return sdk.Context{}, types.ErrExceedMaxCallDepth
}

// set updated stack size
return types.WithCallDepth(sdk.UnwrapSDKContext(ctx), callDepth), nil
}

// QueryRaw returns the contract's state for give key. Returns `nil` when key is `nil`.
func (k Keeper) QueryRaw(ctx context.Context, contractAddress sdk.AccAddress, key []byte) []byte {
defer telemetry.MeasureSince(time.Now(), "wasm", "contract", "query-raw")
Expand Down
3 changes: 3 additions & 0 deletions x/wasm/keeper/keeper_cgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ func NewKeeper(
queryGasLimit: wasmConfig.SmartQueryGasLimit,
gasRegister: types.NewDefaultWasmGasRegister(),
maxQueryStackSize: types.DefaultMaxQueryStackSize,
maxCallDepth: types.DefaultMaxCallDepth,
acceptedAccountTypes: defaultAcceptedAccountTypes,
params: collections.NewItem(sb, types.ParamsKey, "params", codec.CollValue[types.Params](cdc)),
propagateGovAuthorization: map[types.AuthorizationPolicyAction]struct{}{
Expand All @@ -63,6 +64,8 @@ func NewKeeper(
for _, o := range preOpts {
o.apply(keeper)
}
// always wrap the messenger, even if it was replaced by an option
keeper.messenger = callDepthMessageHandler{keeper.messenger, keeper.maxCallDepth}
// only set the wasmvm if no one set this in the options
// NewVM does a lot, so better not to create it and silently drop it.
if keeper.wasmVM == nil {
Expand Down
6 changes: 6 additions & 0 deletions x/wasm/keeper/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ func WithMaxQueryStackSize(m uint32) Option {
})
}

func WithMaxCallDepth(m uint32) Option {
return optsFn(func(k *Keeper) {
k.maxCallDepth = m
})
}

// WithAcceptedAccountTypesOnContractInstantiation sets the accepted account types. Account types of this list won't be overwritten or cause a failure
// when they exist for an address on contract instantiation.
//
Expand Down
14 changes: 11 additions & 3 deletions x/wasm/keeper/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ func TestConstructorOptions(t *testing.T) {
"message handler": {
srcOpt: WithMessageHandler(&wasmtesting.MockMessageHandler{}),
verify: func(t *testing.T, k Keeper) {
assert.IsType(t, &wasmtesting.MockMessageHandler{}, k.messenger)
require.IsType(t, callDepthMessageHandler{}, k.messenger)
messenger, _ := k.messenger.(callDepthMessageHandler)
assert.IsType(t, &wasmtesting.MockMessageHandler{}, messenger.Messenger)
},
},
"query plugins": {
Expand All @@ -70,7 +72,7 @@ func TestConstructorOptions(t *testing.T) {
},
"message handler decorator": {
srcOpt: WithMessageHandlerDecorator(func(old Messenger) Messenger {
require.IsType(t, &MessageHandlerChain{}, old)
require.IsType(t, callDepthMessageHandler{}, old)
return &wasmtesting.MockMessageHandler{}
}),
verify: func(t *testing.T, k Keeper) {
Expand Down Expand Up @@ -108,12 +110,18 @@ func TestConstructorOptions(t *testing.T) {
assert.Equal(t, uint64(2), costCanonical)
},
},
"max recursion query limit": {
"max query recursion limit": {
srcOpt: WithMaxQueryStackSize(1),
verify: func(t *testing.T, k Keeper) {
assert.IsType(t, uint32(1), k.maxQueryStackSize)
},
},
"max message recursion limit": {
srcOpt: WithMaxCallDepth(1),
verify: func(t *testing.T, k Keeper) {
assert.IsType(t, uint32(1), k.maxCallDepth)
},
},
"accepted account types": {
srcOpt: WithAcceptedAccountTypesOnContractInstantiation(&authtypes.BaseAccount{}, &vestingtypes.ContinuousVestingAccount{}),
verify: func(t *testing.T, k Keeper) {
Expand Down
2 changes: 1 addition & 1 deletion x/wasm/keeper/query_plugins_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ func TestQueryErrors(t *testing.T) {
return nil, spec.src
})
ms := store.NewCommitMultiStore(dbm.NewMemDB(), log.NewTestLogger(t), storemetrics.NewNoOpMetrics())
ctx := sdk.Context{}.WithGasMeter(storetypes.NewInfiniteGasMeter()).WithMultiStore(ms).WithLogger(log.NewTestLogger(t))
ctx := sdk.NewContext(ms, cmtproto.Header{}, false, log.NewTestLogger(t)).WithGasMeter(storetypes.NewInfiniteGasMeter())
q := keeper.NewQueryHandler(ctx, mock, sdk.AccAddress{}, types.NewDefaultWasmGasRegister())
_, gotErr := q.Query(wasmvmtypes.QueryRequest{}, 1)
assert.Equal(t, spec.expErr, gotErr)
Expand Down
11 changes: 11 additions & 0 deletions x/wasm/types/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ const (
contextKeySubMsgAuthzPolicy = iota
// gas register
contextKeyGasRegister = iota

contextKeyCallDepth contextKey = iota
)

// WithTXCounter stores a transaction counter value in the context
Expand All @@ -43,6 +45,15 @@ func QueryStackSize(ctx context.Context) (uint32, bool) {
return val, ok
}

func WithCallDepth(ctx sdk.Context, counter uint32) sdk.Context {
return ctx.WithValue(contextKeyCallDepth, counter)
}

func CallDepth(ctx context.Context) (uint32, bool) {
val, ok := ctx.Value(contextKeyCallDepth).(uint32)
return val, ok
}

// WithSubMsgAuthzPolicy stores the authorization policy for submessages into the context returned
func WithSubMsgAuthzPolicy(ctx sdk.Context, policy AuthorizationPolicy) sdk.Context {
if policy == nil {
Expand Down
3 changes: 3 additions & 0 deletions x/wasm/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ var (

// ErrVMError means an error occurred in wasmvm (not in the contract itself, but in the host environment)
ErrVMError = errorsmod.Register(DefaultCodespace, 29, "wasmvm error")

// ErrExceedMaxCallDepth error if max message stack size is exceeded
ErrExceedMaxCallDepth = errorsmod.Register(DefaultCodespace, 30, "max call depth exceeded")
)

// WasmVMErrorable mapped error type in wasmvm and are not redacted
Expand Down
4 changes: 3 additions & 1 deletion x/wasm/types/wasmer_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
storetypes "cosmossdk.io/store/types"
)

// DefaultMaxQueryStackSize maximum size of the stack of contract instances doing queries
// DefaultMaxQueryStackSize maximum size of the stack of recursive queries a contract can make
const DefaultMaxQueryStackSize uint32 = 10

const DefaultMaxCallDepth uint32 = 500

// WasmEngine defines the WASM contract runtime engine.
type WasmEngine interface {
// StoreCode will compile the Wasm code, and store the resulting compiled module
Expand Down

0 comments on commit 465a2a8

Please sign in to comment.