diff --git a/src/test/ERC4626.t.sol b/src/test/ERC4626.t.sol index 816c8e48..2a79545d 100644 --- a/src/test/ERC4626.t.sol +++ b/src/test/ERC4626.t.sol @@ -331,49 +331,57 @@ contract ERC4626Test is DSTestPlus { assertEq(underlying.balanceOf(address(vault)), 0); } - function testFailDepositWithNotEnoughApproval() public { + function testRevertDepositWithNotEnoughApproval() public { underlying.mint(address(this), 0.5e18); underlying.approve(address(vault), 0.5e18); assertEq(underlying.allowance(address(this), address(vault)), 0.5e18); + hevm.expectRevert(); vault.deposit(1e18, address(this)); } - function testFailWithdrawWithNotEnoughUnderlyingAmount() public { + function testRevertWithdrawWithNotEnoughUnderlyingAmount() public { underlying.mint(address(this), 0.5e18); underlying.approve(address(vault), 0.5e18); vault.deposit(0.5e18, address(this)); + hevm.expectRevert(); vault.withdraw(1e18, address(this), address(this)); } - function testFailRedeemWithNotEnoughShareAmount() public { + function testRevertRedeemWithNotEnoughShareAmount() public { underlying.mint(address(this), 0.5e18); underlying.approve(address(vault), 0.5e18); vault.deposit(0.5e18, address(this)); + hevm.expectRevert(); vault.redeem(1e18, address(this), address(this)); } - function testFailWithdrawWithNoUnderlyingAmount() public { + function testRevertWithdrawWithNoUnderlyingAmount() public { + hevm.expectRevert(); vault.withdraw(1e18, address(this), address(this)); } - function testFailRedeemWithNoShareAmount() public { + function testRevertRedeemWithNoShareAmount() public { + hevm.expectRevert(); vault.redeem(1e18, address(this), address(this)); } - function testFailDepositWithNoApproval() public { + function testRevertDepositWithNoApproval() public { + hevm.expectRevert(); vault.deposit(1e18, address(this)); } - function testFailMintWithNoApproval() public { + function testRevertMintWithNoApproval() public { + hevm.expectRevert(); vault.mint(1e18, address(this)); } - function testFailDepositZero() public { + function testRevertDepositZero() public { + hevm.expectRevert("ZERO_SHARES"); vault.deposit(0, address(this)); } @@ -386,7 +394,8 @@ contract ERC4626Test is DSTestPlus { assertEq(vault.totalAssets(), 0); } - function testFailRedeemZero() public { + function testRevertRedeemZero() public { + hevm.expectRevert("ZERO_ASSETS"); vault.redeem(0, address(this), address(this)); } @@ -443,4 +452,30 @@ contract ERC4626Test is DSTestPlus { assertEq(vault.balanceOf(bob), 0); assertEq(underlying.balanceOf(alice), 1e18); } + + function testRevertMintZeroAssetsWhenVaultDrained(uint128 initialDeposit, uint128 massiveMint) public { + if (initialDeposit == 0) initialDeposit = 1; + if (massiveMint == 0) massiveMint = 1; + + address alice = address(0xABCD); + address attacker = address(0xBAD); + + underlying.mint(alice, initialDeposit); + hevm.prank(alice); + underlying.approve(address(vault), initialDeposit); + hevm.prank(alice); + vault.deposit(initialDeposit, alice); + + // Vault is drained (simulate slashing or hack) + underlying.burn(address(vault), initialDeposit); + + // Attacker mints massive amount of shares for free + underlying.mint(attacker, 0); // Attacker has 0 tokens + hevm.prank(attacker); + underlying.approve(address(vault), 0); + + hevm.expectRevert("ZERO_ASSETS"); + hevm.prank(attacker); + vault.mint(massiveMint, attacker); // Should fail due to ZERO_ASSETS check + } } diff --git a/src/test/utils/Hevm.sol b/src/test/utils/Hevm.sol index 8ca0eff9..86cd088a 100644 --- a/src/test/utils/Hevm.sol +++ b/src/test/utils/Hevm.sol @@ -58,6 +58,9 @@ interface Hevm { /// @notice Sets an address' code. function etch(address, bytes calldata) external; + /// @notice Expects a revert from the next call. + function expectRevert() external; + /// @notice Expects an error from the next call. function expectRevert(bytes calldata) external; diff --git a/src/tokens/ERC4626.sol b/src/tokens/ERC4626.sol index 0a34ac98..cb7b38b7 100644 --- a/src/tokens/ERC4626.sol +++ b/src/tokens/ERC4626.sol @@ -60,6 +60,8 @@ abstract contract ERC4626 is ERC20 { function mint(uint256 shares, address receiver) public virtual returns (uint256 assets) { assets = previewMint(shares); // No need to check for rounding error, previewMint rounds up. + require(assets != 0, "ZERO_ASSETS"); + // Need to transfer before minting or ERC777s could reenter. asset.safeTransferFrom(msg.sender, address(this), assets); diff --git a/src/utils/SafeTransferLib.sol b/src/utils/SafeTransferLib.sol index 7f8236db..cab09dc2 100644 --- a/src/utils/SafeTransferLib.sol +++ b/src/utils/SafeTransferLib.sol @@ -42,8 +42,8 @@ library SafeTransferLib { // Write the abi-encoded calldata into memory, beginning with the function selector. mstore(freeMemoryPointer, 0x23b872dd00000000000000000000000000000000000000000000000000000000) - mstore(add(freeMemoryPointer, 4), and(from, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "from" argument. - mstore(add(freeMemoryPointer, 36), and(to, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "to" argument. + mstore(add(freeMemoryPointer, 4), from) // Append the "from" argument. + mstore(add(freeMemoryPointer, 36), to) // Append the "to" argument. mstore(add(freeMemoryPointer, 68), amount) // Append the "amount" argument. Masking not required as it's a full 32 byte type. // We use 100 because the length of our calldata totals up like so: 4 + 32 * 3. @@ -74,7 +74,7 @@ library SafeTransferLib { // Write the abi-encoded calldata into memory, beginning with the function selector. mstore(freeMemoryPointer, 0xa9059cbb00000000000000000000000000000000000000000000000000000000) - mstore(add(freeMemoryPointer, 4), and(to, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "to" argument. + mstore(add(freeMemoryPointer, 4), to) // Append the "to" argument. mstore(add(freeMemoryPointer, 36), amount) // Append the "amount" argument. Masking not required as it's a full 32 byte type. // We use 68 because the length of our calldata totals up like so: 4 + 32 * 2. @@ -105,7 +105,7 @@ library SafeTransferLib { // Write the abi-encoded calldata into memory, beginning with the function selector. mstore(freeMemoryPointer, 0x095ea7b300000000000000000000000000000000000000000000000000000000) - mstore(add(freeMemoryPointer, 4), and(to, 0xffffffffffffffffffffffffffffffffffffffff)) // Append and mask the "to" argument. + mstore(add(freeMemoryPointer, 4), to) // Append the "to" argument. mstore(add(freeMemoryPointer, 36), amount) // Append the "amount" argument. Masking not required as it's a full 32 byte type. // We use 68 because the length of our calldata totals up like so: 4 + 32 * 2.