diff --git a/src/enforcers/ReturnValueComparisonEnforcer.sol b/src/enforcers/ReturnValueComparisonEnforcer.sol new file mode 100644 index 00000000..6ae9b2df --- /dev/null +++ b/src/enforcers/ReturnValueComparisonEnforcer.sol @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT AND Apache-2.0 +pragma solidity 0.8.23; + +import { CaveatEnforcer } from "./CaveatEnforcer.sol"; +import { ModeCode } from "../utils/Types.sol"; + +enum ComparisonOperator { + EQ, // Equal (bytes32 hash equality) + NEQ, // Not Equal (bytes32 hash inequality) + GTE, // Greater Than or Equal (numeric, supports uint/int/bool) + LTE // Less Than or Equal (numeric, supports uint/int/bool) + +} + +enum ValueType { + UINT256, + INT256, + UINT128, + INT128, + BOOL +} + +/** + * @title ReturnValueComparisonEnforcer + * @notice Enforces that the return value of a staticcall matches a comparison against a specified term. + * @dev The `_terms` parameter encodes the target, calldata, comparison operator, and value to compare against. + * For EQ/NEQ, compares keccak256 hashes of the return and expected value (works for structs, tuples, etc). + * For GTE/LTE, supports uint256, int256, uint128, int128, and bool (by type length). + * Example use case: Only allow execution if a collateral ratio is below a threshold. + * + * _terms encoding: abi.encode(target (address), callData (bytes), operator (uint8), expectedValue (bytes)) + */ +contract ReturnValueComparisonEnforcer is CaveatEnforcer { + /** + * @notice Checks that the return value of a staticcall matches the comparison. + * @dev Expects _terms = abi.encode(target, callData, operator, expectedValue) + */ + function beforeHook( + bytes calldata _terms, + bytes calldata, + ModeCode _mode, + bytes calldata, + bytes32, + address, + address + ) + public + view + override + onlyDefaultExecutionMode(_mode) + { + (address target, bytes memory callData, ComparisonOperator op, ValueType typeTag, bytes memory expected) = + abi.decode(_terms, (address, bytes, ComparisonOperator, ValueType, bytes)); + + (bool success, bytes memory result) = target.staticcall(callData); + require(success, "ReturnValueComparisonEnforcer:staticcall-failed"); + + if (op == ComparisonOperator.EQ) { + require(keccak256(result) == keccak256(expected), "not-equal"); + } else if (op == ComparisonOperator.NEQ) { + require(keccak256(result) != keccak256(expected), "equal"); + } else if (op == ComparisonOperator.GTE || op == ComparisonOperator.LTE) { + require(result.length == expected.length, "length-mismatch"); + if (typeTag == ValueType.UINT256) { + uint256 actual = abi.decode(result, (uint256)); + uint256 exp = abi.decode(expected, (uint256)); + require(op == ComparisonOperator.GTE ? actual >= exp : actual <= exp, op == ComparisonOperator.GTE ? "lt" : "gt"); + } else if (typeTag == ValueType.INT256) { + int256 actual = abi.decode(result, (int256)); + int256 exp = abi.decode(expected, (int256)); + require(op == ComparisonOperator.GTE ? actual >= exp : actual <= exp, op == ComparisonOperator.GTE ? "lt" : "gt"); + } else if (typeTag == ValueType.UINT128) { + uint128 actual = abi.decode(result, (uint128)); + uint128 exp = abi.decode(expected, (uint128)); + require(op == ComparisonOperator.GTE ? actual >= exp : actual <= exp, op == ComparisonOperator.GTE ? "lt" : "gt"); + } else if (typeTag == ValueType.INT128) { + int128 actual = abi.decode(result, (int128)); + int128 exp = abi.decode(expected, (int128)); + require(op == ComparisonOperator.GTE ? actual >= exp : actual <= exp, op == ComparisonOperator.GTE ? "lt" : "gt"); + } else if (typeTag == ValueType.BOOL) { + bool actual = abi.decode(result, (bool)); + bool exp = abi.decode(expected, (bool)); + require( + op == ComparisonOperator.GTE ? (actual == exp || (actual && !exp)) : (actual == exp || (!actual && exp)), + op == ComparisonOperator.GTE ? "lt" : "gt" + ); + } else { + revert("unsupported-type"); + } + } else { + revert("invalid-operator"); + } + } +} diff --git a/test/enforcers/ReturnValueComparisonEnforcer.t.sol b/test/enforcers/ReturnValueComparisonEnforcer.t.sol new file mode 100644 index 00000000..626640ec --- /dev/null +++ b/test/enforcers/ReturnValueComparisonEnforcer.t.sol @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT AND Apache-2.0 +pragma solidity 0.8.23; + +import "forge-std/Test.sol"; +import { ModeCode } from "../../src/utils/Types.sol"; +import { ReturnValueComparisonEnforcer, ComparisonOperator, ValueType } from "../../src/enforcers/ReturnValueComparisonEnforcer.sol"; + +contract DummyReader { + uint256 public value; + int256 public ivalue; + uint128 public u128value; + int128 public i128value; + bool public bvalue; + + function set(uint256 v) external { + value = v; + } + + function seti(int256 v) external { + ivalue = v; + } + + function setu128(uint128 v) external { + u128value = v; + } + + function seti128(int128 v) external { + i128value = v; + } + + function setb(bool v) external { + bvalue = v; + } + + function read() external view returns (uint256) { + return value; + } + + function readi() external view returns (int256) { + return ivalue; + } + + function readu128() external view returns (uint128) { + return u128value; + } + + function readi128() external view returns (int128) { + return i128value; + } + + function readb() external view returns (bool) { + return bvalue; + } +} + +contract ReturnValueComparisonEnforcerTest is Test { + ReturnValueComparisonEnforcer public enforcer; + DummyReader public reader; + + function setUp() public { + enforcer = new ReturnValueComparisonEnforcer(); + reader = new DummyReader(); + } + + function testEQ_uint256() public { + reader.set(42); + bytes memory callData = abi.encodeWithSignature("read()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.EQ, ValueType.UINT256, abi.encode(uint256(42))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testNEQ_uint256() public { + reader.set(43); + bytes memory callData = abi.encodeWithSignature("read()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.NEQ, ValueType.UINT256, abi.encode(uint256(42))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testGTE_uint256() public { + reader.set(100); + bytes memory callData = abi.encodeWithSignature("read()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.GTE, ValueType.UINT256, abi.encode(uint256(42))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testLTE_uint256() public { + reader.set(10); + bytes memory callData = abi.encodeWithSignature("read()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.LTE, ValueType.UINT256, abi.encode(uint256(42))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testEQ_bool() public { + reader.setb(true); + bytes memory callData = abi.encodeWithSignature("readb()"); + bytes memory terms = abi.encode(address(reader), callData, ComparisonOperator.EQ, ValueType.BOOL, abi.encode(true)); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testGTE_int256() public { + reader.seti(-1); + bytes memory callData = abi.encodeWithSignature("readi()"); + bytes memory terms = abi.encode(address(reader), callData, ComparisonOperator.GTE, ValueType.INT256, abi.encode(int256(-2))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testLTE_int256() public { + reader.seti(-5); + bytes memory callData = abi.encodeWithSignature("readi()"); + bytes memory terms = abi.encode(address(reader), callData, ComparisonOperator.LTE, ValueType.INT256, abi.encode(int256(-2))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testEQ_uint128() public { + reader.setu128(123); + bytes memory callData = abi.encodeWithSignature("readu128()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.EQ, ValueType.UINT128, abi.encode(uint128(123))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function testGTE_int128() public { + reader.seti128(-10); + bytes memory callData = abi.encodeWithSignature("readi128()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.GTE, ValueType.INT128, abi.encode(int128(-20))); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function test_RevertWhen_NEQ_uint256_shouldRevert() public { + reader.set(42); + bytes memory callData = abi.encodeWithSignature("read()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.NEQ, ValueType.UINT256, abi.encode(uint256(42))); + vm.expectRevert("equal"); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } + + function test_RevertWhen_GTE_uint256_shouldRevert() public { + reader.set(10); + bytes memory callData = abi.encodeWithSignature("read()"); + bytes memory terms = + abi.encode(address(reader), callData, ComparisonOperator.GTE, ValueType.UINT256, abi.encode(uint256(42))); + vm.expectRevert(bytes("lt")); + enforcer.beforeHook(terms, "", ModeCode.wrap(bytes32(0)), "", bytes32(0), address(0), address(0)); + } +}