Skip to content

Commit 9f9808b

Browse files
committed
populate roster api on the session
1 parent 7a359ba commit 9f9808b

File tree

8 files changed

+38
-2
lines changed

8 files changed

+38
-2
lines changed

cmd/api_example/main.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ verify(const std::string& label, Session& alice, Session& bob)
3636
throw std::runtime_error(label + ": not equal");
3737
}
3838

39+
if (alice.roster() != bob.roster()) {
40+
throw std::runtime_error(label + ": roster not equal");
41+
}
42+
3943
verify_send(label, alice, bob);
4044
verify_send(label, bob, alice);
4145
}

include/mls/credential.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ operator<<(tls::ostream& str, const X509Credential& obj);
6868
tls::istream&
6969
operator>>(tls::istream& str, X509Credential& obj);
7070

71+
bool
72+
operator==(const X509Credential& lhs, const X509Credential& rhs);
73+
7174
// struct {
7275
// CredentialType credential_type;
7376
// select (credential_type) {

include/mls/session.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ class Session
6868
bytes do_export(const std::string& label,
6969
const bytes& context,
7070
size_t size) const;
71+
std::vector<Credential> roster() const;
7172

7273
// Application message protection
7374
bytes protect(const bytes& plaintext);

include/mls/state.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class State
8383
bytes do_export(const std::string& label,
8484
const bytes& context,
8585
size_t size) const;
86+
std::vector<Credential> get_leaf_credentials() const;
8687

8788
///
8889
/// General encryption and decryption

src/credential.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ operator>>(tls::istream& str, X509Credential& obj)
8686
return str;
8787
}
8888

89+
bool
90+
operator==(const X509Credential& lhs, const X509Credential& rhs)
91+
{
92+
return lhs.der_chain == rhs.der_chain;
93+
}
94+
8995
///
9096
/// Credential
9197
///

src/session.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,12 @@ Session::do_export(const std::string& label,
347347
return inner->history.front().do_export(label, context, size);
348348
}
349349

350+
std::vector<Credential>
351+
Session::roster() const
352+
{
353+
return inner->history.front().get_leaf_credentials();
354+
}
355+
350356
bytes
351357
Session::protect(const bytes& plaintext)
352358
{

src/state.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,21 @@ State::do_export(const std::string& label,
580580
return _suite.expand_with_label(secret, "exporter", context, size);
581581
}
582582

583+
std::vector<Credential>
584+
State::get_leaf_credentials() const
585+
{
586+
std::vector<Credential> creds;
587+
588+
for (uint32_t i = 0; i < _tree.size().val; i++) {
589+
const auto& ln = _tree.node_at(LeafIndex{ i });
590+
if (ln.node.has_value()) {
591+
creds.push_back(ln.key_package().credential);
592+
}
593+
}
594+
595+
return creds;
596+
}
597+
583598
MLSCiphertext
584599
State::encrypt(const MLSPlaintext& pt)
585600
{

test/credential.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ TEST_CASE("X509 Credential Depth 2 Marshal/Unmarshal")
7474

7575
auto marshalled = tls::marshal(original);
7676
auto unmarshaled = tls::get<Credential>(marshalled);
77-
CHECK(original.public_key() == unmarshaled.public_key());
77+
CHECK(original == unmarshaled);
7878

7979
auto x509 = unmarshaled.get<X509Credential>();
8080
CHECK(x509.der_chain == der_in);
@@ -99,7 +99,7 @@ TEST_CASE("X509 Credential Depth 1 Marshal/Unmarshal")
9999

100100
auto marshalled = tls::marshal(original);
101101
auto unmarshaled = tls::get<Credential>(marshalled);
102-
CHECK(original.public_key() == unmarshaled.public_key());
102+
CHECK(original == unmarshaled);
103103

104104
auto x509 = unmarshaled.get<X509Credential>();
105105
CHECK(x509.der_chain == der_in);

0 commit comments

Comments
 (0)