Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add validation for multi message execution wasm #2092

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ and the [related comments from @Unique-Divine and @berndartmueller](https://gith
- [#2084](https://github.com/NibiruChain/nibiru/pull/2084) - feat(evm-forge): foundry support and template for Nibiru EVM develoment
- [#2088](https://github.com/NibiruChain/nibiru/pull/2088) - refactor(evm): remove outdated comment and improper error message text
- [#2089](https://github.com/NibiruChain/nibiru/pull/2089) - better handling of gas consumption within erc20 contract execution
- [#2092](https://github.com/NibiruChain/nibiru/pull/2092) - feat(evm): add validation for wasm multi message execution


#### Dapp modules: perp, spot, oracle, etc
Expand Down
51 changes: 37 additions & 14 deletions x/evm/precompile/wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ func (p precompileWasm) executeMulti(
err = ErrMethodCalled(method, err)
}
}()

if err := assertNotReadonlyTx(readOnly, true); err != nil {
return bz, err
}
Expand All @@ -295,29 +296,51 @@ func (p precompileWasm) executeMulti(
err = ErrInvalidArgs(err)
return
}
callerBech32 := eth.EthAddrToNibiruAddr(caller)

var responses [][]byte
for _, m := range wasmExecMsgs {
wasmContract, e := sdk.AccAddressFromBech32(m.ContractAddr)
// Validate all messages before executing any of them
type validatedMsg struct {
contract sdk.AccAddress
msgArgs []byte
funds sdk.Coins
}

validatedMsgs := make([]validatedMsg, 0, len(wasmExecMsgs))

// Validate each message using parseExecuteArgs
for _, msg := range wasmExecMsgs {
// Create args array in the format expected by parseExecuteArgs
singleMsgArgs := []any{
msg.ContractAddr,
msg.MsgArgs,
msg.Funds,
}

contract, msgArgs, funds, e := p.parseExecuteArgs(singleMsgArgs)
if e != nil {
err = fmt.Errorf("Execute failed: %w", e)
err = fmt.Errorf("validation failed for contract %s: %w", msg.ContractAddr, e)
return
}
var funds sdk.Coins
for _, fund := range m.Funds {
funds = append(funds, sdk.Coin{
Denom: fund.Denom,
Amount: sdk.NewIntFromBigInt(fund.Amount),
})
}
respBz, e := p.Wasm.Execute(ctx, wasmContract, callerBech32, m.MsgArgs, funds)

validatedMsgs = append(validatedMsgs, validatedMsg{
contract: contract,
msgArgs: msgArgs,
funds: funds,
})
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Enhance validation by collecting all errors before returning

Currently, the validation loop returns upon encountering the first error. Collecting all validation errors before returning could provide more comprehensive feedback to the caller.

You might modify the code as follows:

 var validationErrors []error

 for _, msg := range wasmExecMsgs {
     // Create args array in the format expected by parseExecuteArgs
     singleMsgArgs := []any{
         msg.ContractAddr,
         msg.MsgArgs,
         msg.Funds,
     }

     contract, msgArgs, funds, e := p.parseExecuteArgs(singleMsgArgs)
     if e != nil {
-        err = fmt.Errorf("validation failed for contract %s: %w", msg.ContractAddr, e)
-        return
+        validationErrors = append(validationErrors, fmt.Errorf("validation failed for contract %s: %w", msg.ContractAddr, e))
         continue
     }

     validatedMsgs = append(validatedMsgs, validatedMsg{
         contract: contract,
         msgArgs:  msgArgs,
         funds:    funds,
     })
 }

+if len(validationErrors) > 0 {
+    err = fmt.Errorf("validation errors occurred: %v", validationErrors)
+    return
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Validate each message using parseExecuteArgs
for _, msg := range wasmExecMsgs {
// Create args array in the format expected by parseExecuteArgs
singleMsgArgs := []any{
msg.ContractAddr,
msg.MsgArgs,
msg.Funds,
}
contract, msgArgs, funds, e := p.parseExecuteArgs(singleMsgArgs)
if e != nil {
err = fmt.Errorf("Execute failed: %w", e)
err = fmt.Errorf("validation failed for contract %s: %w", msg.ContractAddr, e)
return
}
var funds sdk.Coins
for _, fund := range m.Funds {
funds = append(funds, sdk.Coin{
Denom: fund.Denom,
Amount: sdk.NewIntFromBigInt(fund.Amount),
})
}
respBz, e := p.Wasm.Execute(ctx, wasmContract, callerBech32, m.MsgArgs, funds)
validatedMsgs = append(validatedMsgs, validatedMsg{
contract: contract,
msgArgs: msgArgs,
funds: funds,
})
}
// Validate each message using parseExecuteArgs
var validationErrors []error
for _, msg := range wasmExecMsgs {
// Create args array in the format expected by parseExecuteArgs
singleMsgArgs := []any{
msg.ContractAddr,
msg.MsgArgs,
msg.Funds,
}
contract, msgArgs, funds, e := p.parseExecuteArgs(singleMsgArgs)
if e != nil {
validationErrors = append(validationErrors, fmt.Errorf("validation failed for contract %s: %w", msg.ContractAddr, e))
continue
}
validatedMsgs = append(validatedMsgs, validatedMsg{
contract: contract,
msgArgs: msgArgs,
funds: funds,
})
}
if len(validationErrors) > 0 {
err = fmt.Errorf("validation errors occurred: %v", validationErrors)
return
}


callerBech32 := eth.EthAddrToNibiruAddr(caller)
var responses [][]byte

// Execute all messages after validation
for _, msg := range validatedMsgs {
respBz, e := p.Wasm.Execute(ctx, msg.contract, callerBech32, msg.msgArgs, msg.funds)
if e != nil {
err = e
err = fmt.Errorf("execute failed for contract %s: %w", msg.contract.String(), e)
return
}
responses = append(responses, respBz)
}

return method.Outputs.Pack(responses)
}

Expand Down
186 changes: 186 additions & 0 deletions x/evm/precompile/wasm_test.go
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test cases look good

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/NibiruChain/nibiru/v2/app"
"github.com/NibiruChain/nibiru/v2/x/common/testutil"
"github.com/NibiruChain/nibiru/v2/x/common/testutil/testapp"
"github.com/NibiruChain/nibiru/v2/x/evm/embeds"
"github.com/NibiruChain/nibiru/v2/x/evm/evmtest"
"github.com/NibiruChain/nibiru/v2/x/evm/precompile"
Expand Down Expand Up @@ -585,3 +586,188 @@ func (s *WasmSuite) TestSadArgsExecute() {
})
}
}

type WasmExecuteMsg struct {
ContractAddr string `json:"contractAddr"`
MsgArgs []byte `json:"msgArgs"`
Funds []precompile.WasmBankCoin `json:"funds"`
}

func (s *WasmSuite) TestExecuteMultiValidation() {
deps := evmtest.NewTestDeps()

s.Require().NoError(testapp.FundAccount(
deps.App.BankKeeper,
deps.Ctx,
deps.Sender.NibiruAddr,
sdk.NewCoins(sdk.NewCoin("unibi", sdk.NewInt(100))),
))

wasmContracts := SetupWasmContracts(&deps, &s.Suite)
wasmContract := wasmContracts[1] // hello_world_counter.wasm

invalidMsgArgsBz := []byte(`{"invalid": "json"}`) // Invalid message format
validMsgArgsBz := []byte(`{"increment": {}}`) // Valid increment message

var emptyFunds []precompile.WasmBankCoin
validFunds := []precompile.WasmBankCoin{{
Denom: "unibi",
Amount: big.NewInt(100),
}}
invalidFunds := []precompile.WasmBankCoin{{
Denom: "invalid!denom",
Amount: big.NewInt(100),
}}

testCases := []struct {
name string
executeMsgs []WasmExecuteMsg
wantError string
}{
{
name: "valid - single message",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "",
},
{
name: "valid - multiple messages",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: validFunds,
},
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "",
},
{
name: "invalid - malformed contract address",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: "invalid-address",
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "decoding bech32 failed",
},
{
name: "invalid - malformed message args",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: invalidMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "unknown variant",
},
{
name: "invalid - malformed funds",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: invalidFunds,
},
},
wantError: "invalid coins",
},
{
name: "invalid - second message fails validation",
executeMsgs: []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: validMsgArgsBz,
Funds: emptyFunds,
},
{
ContractAddr: wasmContract.String(),
MsgArgs: invalidMsgArgsBz,
Funds: emptyFunds,
},
},
wantError: "unknown variant",
},
}

for _, tc := range testCases {
s.Run(tc.name, func() {
callArgs := []any{tc.executeMsgs}
input, err := embeds.SmartContract_Wasm.ABI.Pack(
string(precompile.WasmMethod_executeMulti),
callArgs...,
)
s.Require().NoError(err)

ethTxResp, err := deps.EvmKeeper.CallContractWithInput(
deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input,
)

if tc.wantError != "" {
s.Require().Error(err)
s.Require().Contains(err.Error(), tc.wantError)
s.Require().Nil(ethTxResp)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve error assertions by checking error types instead of error messages

Comparing error messages using substrings can be brittle and may lead to fragile tests. Consider checking for specific error types or using error wrapping to make your tests more robust and maintainable.

} else {
s.Require().NoError(err)
s.Require().NotNil(ethTxResp)
s.Require().NotEmpty(ethTxResp.Ret)
}
})
}
}

// TestExecuteMultiPartialExecution ensures that no state changes occur if any message
// in the batch fails validation
func (s *WasmSuite) TestExecuteMultiPartialExecution() {
deps := evmtest.NewTestDeps()
wasmContracts := SetupWasmContracts(&deps, &s.Suite)
wasmContract := wasmContracts[1] // hello_world_counter.wasm

// First verify initial state is 0
s.assertWasmCounterState(deps, wasmContract, 0)

// Create a batch where the second message will fail validation
executeMsgs := []WasmExecuteMsg{
{
ContractAddr: wasmContract.String(),
MsgArgs: []byte(`{"increment": {}}`),
Funds: []precompile.WasmBankCoin{},
},
{
ContractAddr: wasmContract.String(),
MsgArgs: []byte(`{"invalid": "json"}`), // This will fail validation
Funds: []precompile.WasmBankCoin{},
},
}

callArgs := []any{executeMsgs}
input, err := embeds.SmartContract_Wasm.ABI.Pack(
string(precompile.WasmMethod_executeMulti),
callArgs...,
)
s.Require().NoError(err)

ethTxResp, err := deps.EvmKeeper.CallContractWithInput(
deps.Ctx, deps.Sender.EthAddr, &precompile.PrecompileAddr_Wasm, true, input,
)

// Verify that the call failed
s.Require().Error(err)
s.Require().Contains(err.Error(), "unknown variant")
s.Require().Nil(ethTxResp)

// Verify that no state changes occurred
s.assertWasmCounterState(deps, wasmContract, 0)
}
Loading