Skip to content

Commit

Permalink
Apply review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
FAlbertDev authored and lieser committed Oct 31, 2023
1 parent 5e53a87 commit 344e6dc
Show file tree
Hide file tree
Showing 14 changed files with 331 additions and 163 deletions.
28 changes: 17 additions & 11 deletions src/lib/pubkey/hss_lms/hss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ HSS_LMS_PrivateKeyInternal::HSS_LMS_PrivateKeyInternal(const HSS_LMS_Params& hss
std::shared_ptr<HSS_LMS_PrivateKeyInternal> HSS_LMS_PrivateKeyInternal::from_bytes_or_throw(
std::span<const uint8_t> key_bytes) {
if(key_bytes.size() < sizeof(HSS_Level) + sizeof(HSS_Sig_Idx)) {
throw Decoding_Error("To few private key bytes.");
throw Decoding_Error("Too few private key bytes.");
}
BufferSlicer slicer(key_bytes);

Expand Down Expand Up @@ -166,9 +166,7 @@ std::shared_ptr<HSS_LMS_PrivateKeyInternal> HSS_LMS_PrivateKeyInternal::from_byt
}

secure_vector<uint8_t> HSS_LMS_PrivateKeyInternal::to_bytes() const {
secure_vector<uint8_t> sk_bytes(sizeof(HSS_Level) + sizeof(HSS_Sig_Idx) +
hss_params().L().get() *
(sizeof(LMS_Algorithm_Type) + sizeof(LMOTS_Algorithm_Type)));
secure_vector<uint8_t> sk_bytes(size());
BufferStuffer stuffer(sk_bytes);

stuffer.append_be(hss_params().L());
Expand All @@ -179,9 +177,9 @@ secure_vector<uint8_t> HSS_LMS_PrivateKeyInternal::to_bytes() const {
stuffer.append_be(params.lms_params().algorithm_type());
stuffer.append_be(params.lmots_params().algorithm_type());
}

sk_bytes.insert(sk_bytes.end(), seed().begin(), seed().end());
sk_bytes.insert(sk_bytes.end(), identifier().begin(), identifier().end());
stuffer.append(seed());
stuffer.append(identifier());
BOTAN_ASSERT_NOMSG(stuffer.full());

return sk_bytes;
}
Expand All @@ -199,6 +197,14 @@ HSS_Sig_Idx HSS_LMS_PrivateKeyInternal::reserve_next_idx() {
return next_idx;
}

size_t HSS_LMS_PrivateKeyInternal::size() const {
size_t sk_size = sizeof(HSS_Level) + sizeof(HSS_Sig_Idx);
// The concatenated algorithm types for all layers
sk_size += hss_params().L().get() * (sizeof(LMS_Algorithm_Type) + sizeof(LMOTS_Algorithm_Type));
sk_size += seed().size() + identifier().size();
return sk_size;
}

HSS_LMS_PrivateKeyInternal::HSS_LMS_PrivateKeyInternal(HSS_LMS_Params hss_params,
LMS_Seed hss_seed,
LMS_Identifier identifier) :
Expand Down Expand Up @@ -299,7 +305,7 @@ HSS_LMS_PublicKeyInternal HSS_LMS_PublicKeyInternal::create(const HSS_LMS_Privat
std::shared_ptr<HSS_LMS_PublicKeyInternal> HSS_LMS_PublicKeyInternal::from_bytes_or_throw(
std::span<const uint8_t> key_bytes) {
if(key_bytes.size() < sizeof(HSS_Level)) {
throw Decoding_Error("To few public key bytes.");
throw Decoding_Error("Too few public key bytes.");
}
BufferSlicer slicer(key_bytes);

Expand All @@ -308,7 +314,7 @@ std::shared_ptr<HSS_LMS_PublicKeyInternal> HSS_LMS_PublicKeyInternal::from_bytes
throw Decoding_Error("Invalid number of HSS layers in public HSS-LMS key.");
}

LMS_PublicKey lms_pub_key = LMS_PublicKey::from_bytes_of_throw(slicer);
LMS_PublicKey lms_pub_key = LMS_PublicKey::from_bytes_or_throw(slicer);

if(!slicer.empty()) {
throw Decoding_Error("Public HSS-LMS key contains more bytes than expected.");
Expand Down Expand Up @@ -372,7 +378,7 @@ HSS_Signature::Signed_Pub_Key::Signed_Pub_Key(LMS_Signature sig, LMS_PublicKey p

HSS_Signature HSS_Signature::from_bytes_or_throw(std::span<const uint8_t> sig_bytes) {
if(sig_bytes.size() < sizeof(uint32_t)) {
throw Decoding_Error("To few HSS signature bytes.");
throw Decoding_Error("Too few HSS signature bytes.");
}
BufferSlicer slicer(sig_bytes);

Expand All @@ -384,7 +390,7 @@ HSS_Signature HSS_Signature::from_bytes_or_throw(std::span<const uint8_t> sig_by
std::vector<Signed_Pub_Key> signed_pub_keys;
for(size_t i = 0; i < Nspk; ++i) {
LMS_Signature sig = LMS_Signature::from_bytes_or_throw(slicer);
LMS_PublicKey pub_key = LMS_PublicKey::from_bytes_of_throw(slicer);
LMS_PublicKey pub_key = LMS_PublicKey::from_bytes_or_throw(slicer);
signed_pub_keys.push_back(Signed_Pub_Key(std::move(sig), std::move(pub_key)));
}

Expand Down
5 changes: 5 additions & 0 deletions src/lib/pubkey/hss_lms/hss.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ class BOTAN_TEST_API HSS_LMS_PrivateKeyInternal final {
*/
HSS_Sig_Idx reserve_next_idx();

/**
* @brief Returns the size in bytes the key would have in its encoded format.
*/
size_t size() const;

/**
* @brief Derive the seed and identifier of an LMS tree from its parent LMS tree.
*
Expand Down
5 changes: 0 additions & 5 deletions src/lib/pubkey/hss_lms/hss_lms_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ PseudorandomKeyGeneration::PseudorandomKeyGeneration(std::span<const uint8_t> id
m_input_buffer.resize(identifier.size() + sizeof(uint32_t) + sizeof(uint16_t) + sizeof(uint8_t));
BufferStuffer input_stuffer(m_input_buffer);
input_stuffer.append(identifier);

m_q = input_stuffer.next(sizeof(uint32_t));
m_i = input_stuffer.next(sizeof(uint16_t));
m_j = input_stuffer.next(sizeof(uint8_t)).data();
BOTAN_ASSERT_NOMSG(input_stuffer.full());
}

void PseudorandomKeyGeneration::gen(std::span<uint8_t> out, HashFunction& hash, std::span<const uint8_t> seed) const {
Expand Down
12 changes: 3 additions & 9 deletions src/lib/pubkey/hss_lms/hss_lms_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,17 @@ class PseudorandomKeyGeneration {
/**
* @brief Specify the value for the u32str(q) hash input field
*/
void set_q(uint32_t q) { store_be(q, m_q.data()); }
void set_q(uint32_t q) { store_be(q, std::span(m_input_buffer).last<7>().first<4>().data()); }

/**
* @brief Specify the value for the u16str(i) hash input field
*/
void set_i(uint16_t i) { store_be(i, m_i.data()); }
void set_i(uint16_t i) { store_be(i, std::span(m_input_buffer).last<3>().first<2>().data()); }

/**
* @brief Specify the value for the u8str(j) hash input field
*/
void set_j(uint8_t j) { *m_j = j; }
void set_j(uint8_t j) { m_input_buffer.back() = j; }

/**
* @brief Create a hash value using the preconfigured prefix and a @p seed
Expand All @@ -66,12 +66,6 @@ class PseudorandomKeyGeneration {
private:
/// Input buffer containing the prefix: 'identifier || u32str(q) || u16str(i) || u8str(j)'
std::vector<uint8_t> m_input_buffer;
/// Subspan of m_input_buffer representing 'u32str(q)'
std::span<uint8_t> m_q;
/// Subspan of m_input_buffer representing 'u26str(i)'
std::span<uint8_t> m_i;
/// Pointer to m_input_buffer at 'u8str(j)'
uint8_t* m_j;
};

} // namespace Botan
Expand Down
140 changes: 104 additions & 36 deletions src/lib/pubkey/hss_lms/lm_ots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,45 +101,113 @@ std::vector<uint8_t> gen_Q_with_cksm(const LMOTS_Params& params,
} // namespace

LMOTS_Params LMOTS_Params::create_or_throw(LMOTS_Algorithm_Type type) {
uint8_t type_value = checked_cast_to_or_throw<uint8_t, Decoding_Error>(type, "Unsupported LM-OTS algorithm type");

if(type >= LMOTS_Algorithm_Type::SHA256_N32_W1 && type <= LMOTS_Algorithm_Type::SHA256_N32_W8) {
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHA256_N32_W1));
return LMOTS_Params(type, "SHA-256", w);
}
if(type >= LMOTS_Algorithm_Type::SHA256_N24_W1 && type <= LMOTS_Algorithm_Type::SHA256_N24_W8) {
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHA256_N24_W1));
return LMOTS_Params(type, "Truncated(SHA-256,192)", w);
}
if(type >= LMOTS_Algorithm_Type::SHAKE_N32_W1 && type <= LMOTS_Algorithm_Type::SHAKE_N32_W8) {
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHAKE_N32_W1));
return LMOTS_Params(type, "SHAKE-256(256)", w);
}
if(type >= LMOTS_Algorithm_Type::SHAKE_N24_W1 && type <= LMOTS_Algorithm_Type::SHAKE_N24_W8) {
uint8_t w = 1 << (type_value - checked_cast_to<uint8_t>(LMOTS_Algorithm_Type::SHAKE_N24_W1));
return LMOTS_Params(type, "SHAKE-256(192)", w);
}
auto [hash_name, w] = [](const LMOTS_Algorithm_Type& lmots_type) -> std::pair<std::string_view, uint8_t> {
switch(lmots_type) {
case LMOTS_Algorithm_Type::SHA256_N32_W1:
return {"SHA-256", 1};
case LMOTS_Algorithm_Type::SHA256_N32_W2:
return {"SHA-256", 2};
case LMOTS_Algorithm_Type::SHA256_N32_W4:
return {"SHA-256", 4};
case LMOTS_Algorithm_Type::SHA256_N32_W8:
return {"SHA-256", 8};
case LMOTS_Algorithm_Type::SHA256_N24_W1:
return {"Truncated(SHA-256,192)", 1};
case LMOTS_Algorithm_Type::SHA256_N24_W2:
return {"Truncated(SHA-256,192)", 2};
case LMOTS_Algorithm_Type::SHA256_N24_W4:
return {"Truncated(SHA-256,192)", 4};
case LMOTS_Algorithm_Type::SHA256_N24_W8:
return {"Truncated(SHA-256,192)", 8};
case LMOTS_Algorithm_Type::SHAKE_N32_W1:
return {"SHAKE-256(256)", 1};
case LMOTS_Algorithm_Type::SHAKE_N32_W2:
return {"SHAKE-256(256)", 2};
case LMOTS_Algorithm_Type::SHAKE_N32_W4:
return {"SHAKE-256(256)", 4};
case LMOTS_Algorithm_Type::SHAKE_N32_W8:
return {"SHAKE-256(256)", 8};
case LMOTS_Algorithm_Type::SHAKE_N24_W1:
return {"SHAKE-256(192)", 1};
case LMOTS_Algorithm_Type::SHAKE_N24_W2:
return {"SHAKE-256(192)", 2};
case LMOTS_Algorithm_Type::SHAKE_N24_W4:
return {"SHAKE-256(192)", 4};
case LMOTS_Algorithm_Type::SHAKE_N24_W8:
return {"SHAKE-256(192)", 8};
case LMOTS_Algorithm_Type::RESERVED:
throw Decoding_Error("Unsupported LMS algorithm type");
}
throw Decoding_Error("Unsupported LMS algorithm type");
}(type);

throw Decoding_Error("Unsupported LM-OTS algorithm type");
return LMOTS_Params(type, hash_name, w);
}

LMOTS_Params LMOTS_Params::create_or_throw(std::string_view hash_name, uint8_t w) {
BOTAN_ARG_CHECK(w == 1 || w == 2 || w == 4 || w == 8, "Invalid w value");
auto type_offset = high_bit(w) - 1;
LMOTS_Algorithm_Type base_type;

if(hash_name == "SHA-256") {
base_type = LMOTS_Algorithm_Type::SHA256_N32_W1;
} else if(hash_name == "Truncated(SHA-256,192)") {
base_type = LMOTS_Algorithm_Type::SHA256_N24_W1;
} else if(hash_name == "SHAKE-256(256)") {
base_type = LMOTS_Algorithm_Type::SHAKE_N32_W1;
} else if(hash_name == "SHAKE-256(192)") {
base_type = LMOTS_Algorithm_Type::SHAKE_N24_W1;
} else {
throw Decoding_Error("Unsupported hash function");
if(w != 1 && w != 2 && w != 4 && w != 8) {
throw Decoding_Error("Invalid Winternitz parameter");
}
auto type = checked_cast_to<LMOTS_Algorithm_Type>(checked_cast_to<uint8_t>(base_type) + type_offset);
LMOTS_Algorithm_Type type = [](std::string_view hash, uint8_t w_p) -> LMOTS_Algorithm_Type {
if(hash == "SHA-256") {
switch(w_p) {
case 1:
return LMOTS_Algorithm_Type::SHA256_N32_W1;
case 2:
return LMOTS_Algorithm_Type::SHA256_N32_W2;
case 4:
return LMOTS_Algorithm_Type::SHA256_N32_W4;
case 8:
return LMOTS_Algorithm_Type::SHA256_N32_W8;
default:
throw Decoding_Error("Unsupported Winternitz parameter");
}
}
if(hash == "Truncated(SHA-256,192)") {
switch(w_p) {
case 1:
return LMOTS_Algorithm_Type::SHA256_N24_W1;
case 2:
return LMOTS_Algorithm_Type::SHA256_N24_W2;
case 4:
return LMOTS_Algorithm_Type::SHA256_N24_W4;
case 8:
return LMOTS_Algorithm_Type::SHA256_N24_W8;
default:
throw Decoding_Error("Unsupported Winternitz parameter");
}
}
if(hash == "SHAKE-256(256)") {
switch(w_p) {
case 1:
return LMOTS_Algorithm_Type::SHAKE_N32_W1;
case 2:
return LMOTS_Algorithm_Type::SHAKE_N32_W2;
case 4:
return LMOTS_Algorithm_Type::SHAKE_N32_W4;
case 8:
return LMOTS_Algorithm_Type::SHAKE_N32_W8;
default:
throw Decoding_Error("Unsupported Winternitz parameter");
}
}
if(hash == "SHAKE-256(192)") {
switch(w_p) {
case 1:
return LMOTS_Algorithm_Type::SHAKE_N24_W1;
case 2:
return LMOTS_Algorithm_Type::SHAKE_N24_W2;
case 4:
return LMOTS_Algorithm_Type::SHAKE_N24_W4;
case 8:
return LMOTS_Algorithm_Type::SHAKE_N24_W8;
default:
throw Decoding_Error("Unsupported Winternitz parameter");
}
}
throw Decoding_Error("Unsupported hash function");
}(hash_name, w);

return LMOTS_Params(type, hash_name, w);
}

Expand Down Expand Up @@ -171,7 +239,7 @@ LMOTS_Signature LMOTS_Signature::from_bytes_or_throw(BufferSlicer& slicer) {
size_t total_remaining_bytes = slicer.remaining();
// Alg. 6a. 1. (last 4 bytes) / Alg. 4b. 1.
if(total_remaining_bytes < sizeof(LMOTS_Algorithm_Type)) {
throw Decoding_Error("To few signature bytes while parsing LMOTS signature.");
throw Decoding_Error("Too few signature bytes while parsing LMOTS signature.");
}
// Alg. 6a. 2.b. / Alg. 4b. 2.a.
auto algorithm_type = slicer.copy_be<LMOTS_Algorithm_Type>();
Expand All @@ -180,7 +248,7 @@ LMOTS_Signature LMOTS_Signature::from_bytes_or_throw(BufferSlicer& slicer) {
LMOTS_Params params = LMOTS_Params::create_or_throw(algorithm_type);

if(total_remaining_bytes < size(params)) {
throw Decoding_Error("To few signature bytes while parsing LMOTS signature.");
throw Decoding_Error("Too few signature bytes while parsing LMOTS signature.");
}

// Alg. 4b. 2.d.
Expand Down

0 comments on commit 344e6dc

Please sign in to comment.