Skip to content

Commit

Permalink
Merge pull request #4115 from randombit/jack/ct-choice
Browse files Browse the repository at this point in the history
Add CT::Choice
  • Loading branch information
randombit authored Jun 12, 2024
2 parents 856174e + 3e8815c commit 927aab8
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 55 deletions.
85 changes: 41 additions & 44 deletions src/lib/math/pcurves/pcurves_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,15 @@ class IntMod final {
}
}

constexpr bool is_zero() const { return CT::all_zeros(m_val.data(), m_val.size()).as_bool(); }
constexpr CT::Choice is_zero() const { return CT::all_zeros(m_val.data(), m_val.size()).as_choice(); }

constexpr bool is_nonzero() const { return !is_zero(); }
constexpr CT::Choice is_nonzero() const { return !is_zero(); }

constexpr bool is_one() const { return (*this == Self::one()); }
constexpr CT::Choice is_one() const { return (*this == Self::one()); }

constexpr bool is_even() const {
constexpr CT::Choice is_even() const {
auto v = Rep::from_rep(m_val);
return (v[0] & 0x01) == 0;
return CT::Choice::from_int(v[0] & 0x01);
}

friend constexpr Self operator+(const Self& a, const Self& b) {
Expand Down Expand Up @@ -167,8 +167,8 @@ class IntMod final {
}

// if cond is true, assigns other to *this
constexpr void conditional_assign(bool cond, const Self& other) {
CT::conditional_assign_mem(static_cast<W>(cond), m_val.data(), other.data(), N);
constexpr void conditional_assign(CT::Choice cond, const Self& other) {
CT::conditional_assign_mem(cond, m_val.data(), other.data(), N);
}

constexpr Self square() const {
Expand Down Expand Up @@ -246,26 +246,26 @@ class IntMod final {
requires(Self::P_MOD_4 == 3)
{
auto z = pow_vartime(Self::P_PLUS_1_OVER_4);
const bool correct = (z.square() == *this);
const CT::Choice correct = (z.square() == *this);
z.conditional_assign(!correct, Self::zero());
return z;
}

constexpr bool is_square() const
constexpr CT::Choice is_square() const
requires(Self::P_MOD_4 == 3)
{
auto z = pow_vartime(Self::P_MINUS_1_OVER_2);
const bool is_one = z.is_one();
const bool is_zero = z.is_zero();
const CT::Choice is_one = z.is_one();
const CT::Choice is_zero = z.is_zero();
return (is_one || is_zero);
}

constexpr bool operator==(const Self& other) const {
return CT::is_equal(this->data(), other.data(), N).as_bool();
constexpr CT::Choice operator==(const Self& other) const {
return CT::is_equal(this->data(), other.data(), N).as_choice();
}

constexpr bool operator!=(const Self& other) const {
return CT::is_not_equal(this->data(), other.data(), N).as_bool();
constexpr CT::Choice operator!=(const Self& other) const {
return CT::is_not_equal(this->data(), other.data(), N).as_choice();
}

constexpr std::array<W, Self::N> to_words() const { return Rep::from_rep(m_val); }
Expand Down Expand Up @@ -372,7 +372,7 @@ class IntMod final {
}

if(auto s = Self::deserialize(buf)) {
if(!s.value().is_zero()) {
if(s.value().is_nonzero().as_bool()) {
return s.value();
}
}
Expand Down Expand Up @@ -422,7 +422,7 @@ class AffineCurvePoint {

static constexpr Self identity() { return Self(FieldElement::zero(), FieldElement::zero()); }

constexpr bool is_identity() const { return x().is_zero() && y().is_zero(); }
constexpr CT::Choice is_identity() const { return x().is_zero() && y().is_zero(); }

AffineCurvePoint(const Self& other) = default;
AffineCurvePoint(Self&& other) = default;
Expand All @@ -440,8 +440,7 @@ class AffineCurvePoint {
}

constexpr void serialize_compressed_to(std::span<uint8_t, Self::COMPRESSED_BYTES> bytes) const {
const bool y_is_even = y().is_even();
const uint8_t hdr = y_is_even ? 0x02 : 0x03;
const uint8_t hdr = CT::Mask<uint8_t>::from_choice(y().is_even()).select(0x02, 0x03);

BufferStuffer pack(bytes);
pack.append(hdr);
Expand All @@ -460,7 +459,7 @@ class AffineCurvePoint {
// Intentionally wrapping; set to maximum size_t if idx == 0
const size_t idx1 = static_cast<size_t>(idx - 1);
for(size_t i = 0; i != pts.size(); ++i) {
const bool found = (idx1 == i);
const auto found = CT::Mask<size_t>::is_equal(idx1, i).as_choice();
result.conditional_assign(found, pts[i]);
}

Expand All @@ -478,7 +477,9 @@ class AffineCurvePoint {
auto y = FieldElement::deserialize(bytes.subspan(1 + FieldElement::BYTES, FieldElement::BYTES));

if(x && y) {
if((*y).square() == Self::x3_ax_b(*x)) {
const auto lhs = (*y).square();
const auto rhs = Self::x3_ax_b(*x);
if((lhs == rhs).as_bool()) {
return Self(*x, *y);
}
}
Expand All @@ -488,14 +489,12 @@ class AffineCurvePoint {
if(bytes[0] != 0x02 && bytes[0] != 0x03) {
return {};
}
const bool y_is_even = (bytes[0] == 0x02);
const CT::Choice y_is_even = CT::Mask<uint8_t>::is_equal(bytes[0], 0x02).as_choice();

if(auto x = FieldElement::deserialize(bytes.subspan(1, FieldElement::BYTES))) {
const auto y2 = x3_ax_b(*x);
auto y = y2.sqrt();
if(y_is_even && !y.is_even()) {
y = y.negate();
}
y.conditional_assign(y_is_even && !y.is_even(), y.negate());
return Self(*x, y);
}

Expand All @@ -509,7 +508,7 @@ class AffineCurvePoint {

constexpr const FieldElement& y() const { return m_y; }

constexpr void conditional_assign(bool cond, const Self& pt) {
constexpr void conditional_assign(CT::Choice cond, const Self& pt) {
m_x.conditional_assign(cond, pt.x());
m_y.conditional_assign(cond, pt.y());
}
Expand Down Expand Up @@ -537,14 +536,14 @@ class ProjectiveCurvePoint {
// recreate it here from the words.
static constexpr FieldElement A = FieldElement::from_words(Params::AW);

static constexpr bool A_is_zero = A.is_zero();
static constexpr bool A_is_minus_3 = (A == FieldElement::constant(-3));
static constexpr bool A_is_zero = A.is_zero().as_bool();
static constexpr bool A_is_minus_3 = (A == FieldElement::constant(-3)).as_bool();

using Self = ProjectiveCurvePoint<FieldElement, Params>;
using AffinePoint = AffineCurvePoint<FieldElement, Params>;

static constexpr Self from_affine(const AffinePoint& pt) {
if(pt.is_identity()) {
if(pt.is_identity().as_bool()) {
return Self::identity();
} else {
return ProjectiveCurvePoint(pt.x(), pt.y());
Expand Down Expand Up @@ -585,14 +584,9 @@ class ProjectiveCurvePoint {

friend constexpr Self operator-(const Self& a, const Self& b) { return a + b.negate(); }

constexpr bool is_identity() const { return z().is_zero(); }

template <typename Pt>
constexpr void conditional_add(bool cond, const Pt& pt) {
conditional_assign(cond, *this + pt);
}
constexpr CT::Choice is_identity() const { return z().is_zero(); }

constexpr void conditional_assign(bool cond, const Self& pt) {
constexpr void conditional_assign(CT::Choice cond, const Self& pt) {
m_x.conditional_assign(cond, pt.x());
m_y.conditional_assign(cond, pt.y());
m_z.conditional_assign(cond, pt.z());
Expand All @@ -601,7 +595,7 @@ class ProjectiveCurvePoint {
constexpr static Self add_mixed(const Self& a, const AffinePoint& b) {
const auto a_is_identity = a.is_identity();
const auto b_is_identity = b.is_identity();
if(a_is_identity && b_is_identity) {
if((a_is_identity && b_is_identity).as_bool()) {
return Self::identity();
}

Expand All @@ -618,7 +612,7 @@ class ProjectiveCurvePoint {
const auto r = S2 - a.y();

// If r is zero then we are in the doubling case
if(r.is_zero()) {
if(r.is_zero().as_bool()) {
return a.dbl();
}

Expand Down Expand Up @@ -652,7 +646,7 @@ class ProjectiveCurvePoint {
constexpr static Self add(const Self& a, const Self& b) {
const auto a_is_identity = a.is_identity();
const auto b_is_identity = b.is_identity();
if(a_is_identity && b_is_identity) {
if((a_is_identity && b_is_identity).as_bool()) {
return Self::identity();
}

Expand All @@ -669,7 +663,7 @@ class ProjectiveCurvePoint {
const auto H = U2 - U1;
const auto r = S2 - S1;

if(r.is_zero()) {
if(r.is_zero().as_bool()) {
return a.dbl();
}

Expand Down Expand Up @@ -750,7 +744,7 @@ class ProjectiveCurvePoint {
constexpr AffinePoint to_affine() const {
// Not strictly required right? - default should work as long
// as (0,0) is identity and invert returns 0 on 0
if(this->is_identity()) {
if(this->is_identity().as_bool()) {
return AffinePoint::identity();
}

Expand All @@ -769,8 +763,11 @@ class ProjectiveCurvePoint {

bool any_identity = false;
for(size_t i = 0; i != N; ++i) {
if(projective[i].is_identity()) {
if(projective[i].is_identity().as_bool()) {
any_identity = true;
// If any of the elements are the identity we fall back to
// performing the conversion without a batch
break;
}
}

Expand Down Expand Up @@ -914,7 +911,7 @@ class EllipticCurve {
static constexpr FieldElement SSWU_Z = FieldElement::constant(Params::Z);

static constexpr bool ValidForSswuHash =
(SSWU_Z.is_nonzero() && A.is_nonzero() && B.is_nonzero() && FieldElement::P_MOD_4 == 3);
(Params::Z != 0 && A.is_nonzero().as_bool() && B.is_nonzero().as_bool() && FieldElement::P_MOD_4 == 3);

// (-B / A), will be zero if A == 0 or B == 0 or Z == 0
static const FieldElement& SSWU_C1()
Expand Down Expand Up @@ -1388,7 +1385,7 @@ inline auto map_to_curve_sswu(const typename C::FieldElement& u) -> typename C::
x.conditional_assign(gx1_is_square, x1);
y.conditional_assign(gx1_is_square, gx1.sqrt());

const bool flip_y = y.is_even() != u.is_even();
const auto flip_y = y.is_even() != u.is_even();
y.conditional_assign(flip_y, y.negate());

auto pt = typename C::AffinePoint(x, y);
Expand Down
10 changes: 7 additions & 3 deletions src/lib/math/pcurves/pcurves_wrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ class PrimeOrderCurveImpl final : public PrimeOrderCurve {

ProjectivePoint point_negate(const ProjectivePoint& pt) const override { return stash(from_stash(pt).negate()); }

bool affine_point_is_identity(const AffinePoint& pt) const override { return from_stash(pt).is_identity(); }
bool affine_point_is_identity(const AffinePoint& pt) const override {
return from_stash(pt).is_identity().as_bool();
}

void serialize_point(std::span<uint8_t> bytes, const AffinePoint& pt) const override {
BOTAN_ARG_CHECK(bytes.size() == C::AffinePoint::BYTES, "Invalid length for serialize_point");
Expand Down Expand Up @@ -177,9 +179,11 @@ class PrimeOrderCurveImpl final : public PrimeOrderCurve {

Scalar scalar_negate(const Scalar& s) const override { return stash(from_stash(s).negate()); }

bool scalar_is_zero(const Scalar& s) const override { return from_stash(s).is_zero(); }
bool scalar_is_zero(const Scalar& s) const override { return from_stash(s).is_zero().as_bool(); }

bool scalar_equal(const Scalar& a, const Scalar& b) const override { return from_stash(a) == from_stash(b); }
bool scalar_equal(const Scalar& a, const Scalar& b) const override {
return (from_stash(a) == from_stash(b)).as_bool();
}

Scalar scalar_zero() const override { return stash(C::Scalar::zero()); }

Expand Down
Loading

0 comments on commit 927aab8

Please sign in to comment.