Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 512bits add and mult operations #5035

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/blue-nails-give.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Math`: Add `add512`, `mul512` and `mulShr`.
58 changes: 49 additions & 9 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,34 @@ library Math {
Expand // Away from zero
}

/**
* @dev Return the 512-bit addition of two uint256.
*
* The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
*/
function add512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
assembly ("memory-safe") {
low := add(a, b)
high := lt(low, a)
}
}

/**
* @dev Return the 512-bit multiplication of two uint256.
*
* The result is stored in two 256 variables such that product = high * 2²⁵⁶ + low.
*/
function mul512(uint256 a, uint256 b) internal pure returns (uint256 high, uint256 low) {
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
// the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
// variables such that product = prod1 * 2²⁵⁶ + prod0.
assembly ("memory-safe") {
let mm := mulmod(a, b, not(0))
low := mul(a, b)
high := sub(sub(mm, low), lt(mm, low))
}
}

/**
* @dev Returns the addition of two unsigned integers, with an success flag (no overflow).
*/
Expand Down Expand Up @@ -143,15 +171,7 @@ library Math {
*/
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
unchecked {
// 512-bit multiply [prod1 prod0] = x * y. Compute the product mod 2²⁵⁶ and mod 2²⁵⁶ - 1, then use
// the Chinese Remainder Theorem to reconstruct the 512 bit result. The result is stored in two 256
// variables such that product = prod1 * 2²⁵⁶ + prod0.
uint256 prod0 = x * y; // Least significant 256 bits of the product
uint256 prod1; // Most significant 256 bits of the product
assembly {
let mm := mulmod(x, y, not(0))
prod1 := sub(sub(mm, prod0), lt(mm, prod0))
}
(uint256 prod1, uint256 prod0) = mul512(x, y);

// Handle non-overflow cases, 256 by 256 division.
if (prod1 == 0) {
Expand Down Expand Up @@ -229,6 +249,26 @@ library Math {
return mulDiv(x, y, denominator) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, denominator) > 0);
}

/**
* @dev Calculates floor(x * y >> n) with full precision. Throws if result overflows a uint256.
*/
function mulShr(uint256 x, uint256 y, uint8 n) internal pure returns (uint256 result) {
unchecked {
(uint256 prod1, uint256 prod0) = mul512(x, y);
if (prod1 >= 1 << n) {
Panic.panic(Panic.UNDER_OVERFLOW);
}
return (prod1 << (256 - n)) | (prod0 >> n);
}
}

/**
* @dev Calculates x * y >> n with full precision, following the selected rounding direction.
*/
function mulShr(uint256 x, uint256 y, uint8 n, Rounding rounding) internal pure returns (uint256) {
return mulShr(x, y, n) + SafeCast.toUint(unsignedRoundsUp(rounding) && mulmod(x, y, 1 << n) > 0);
}

/**
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
*
Expand Down
12 changes: 7 additions & 5 deletions test/helpers/enums.js
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
function Enum(...options) {
return Object.fromEntries(options.map((key, i) => [key, BigInt(i)]));
}
const { ethers } = require('ethers');

const Enum = (...options) => Object.fromEntries(options.map((key, i) => [key, BigInt(i)]));
const EnumTyped = (...options) => Object.fromEntries(options.map((key, i) => [key, ethers.Typed.uint8(i)]));

module.exports = {
Enum,
EnumTyped,
ProposalState: Enum('Pending', 'Active', 'Canceled', 'Defeated', 'Succeeded', 'Queued', 'Expired', 'Executed'),
VoteType: Object.assign(Enum('Against', 'For', 'Abstain'), { Parameters: 255n }),
Rounding: Enum('Floor', 'Ceil', 'Trunc', 'Expand'),
Rounding: EnumTyped('Floor', 'Ceil', 'Trunc', 'Expand'),
OperationState: Enum('Unset', 'Waiting', 'Ready', 'Done'),
RevertType: Enum('None', 'RevertWithoutMessage', 'RevertWithMessage', 'RevertWithCustomError', 'Panic'),
RevertType: EnumTyped('None', 'RevertWithoutMessage', 'RevertWithMessage', 'RevertWithCustomError', 'Panic'),
};
23 changes: 23 additions & 0 deletions test/utils/math/Math.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,29 @@ contract MathTest is Test {
assertEq(Math.ternary(f, a, b), f ? a : b);
}

// ADD512 & MUL512
function testAdd512(uint256 a, uint256 b) public pure {
(uint256 high, uint256 low) = Math.add512(a, b);
(bool success, uint256 result) = Math.tryAdd(a, b);
if (success) {
assertEq(high, 0);
assertEq(low, result);
} else {
assertEq(high, 1);
}
}

function testMul512(uint256 a, uint256 b) public pure {
(uint256 high, uint256 low) = Math.mul512(a, b);
(bool success, uint256 result) = Math.tryMul(a, b);
if (success) {
assertEq(high, 0);
assertEq(low, result);
} else {
assertGt(high, 0);
}
}

// MIN & MAX
function testSymbolicMinMax(uint256 a, uint256 b) public pure {
assertEq(Math.min(a, b), a < b ? a : b);
Expand Down
Loading