Skip to content

Commit

Permalink
Validate the tree on joining (#363)
Browse files Browse the repository at this point in the history
* Properly handle one-member trees without the one member in leaf 0

* Move uniqueness checking to tree operations

* Rename WireFormat enum values to match RFC

* clang-format

* clang-tidy

* Validate the tree on joining

* Make tree validation linear-time

* clang-format

* Minor comment tweak

---------

Co-authored-by: Richard Barnes <[email protected]>
  • Loading branch information
bifurcation and Richard Barnes authored Sep 9, 2023
1 parent a684f2b commit 661758e
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 55 deletions.
8 changes: 8 additions & 0 deletions include/mls/core_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,14 @@ struct Capabilities
bool proposals_supported(const std::vector<uint16_t>& required) const;
bool credential_supported(const Credential& credential) const;

template<typename Container>
bool credentials_supported(const Container& required) const
{
return stdx::all_of(required, [&](CredentialType type) {
return stdx::contains(credentials, type);
});
}

TLS_SERIALIZABLE(versions, cipher_suites, extensions, proposals, credentials)
};

Expand Down
1 change: 1 addition & 0 deletions include/mls/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ class State
TreeKEMPublicKey import_tree(const bytes& tree_hash,
const std::optional<TreeKEMPublicKey>& external,
const ExtensionList& extensions);
bool validate_tree() const;

// Form a commit, covering all the cases with slightly different validation
// rules:
Expand Down
155 changes: 100 additions & 55 deletions src/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,77 @@ State::import_tree(const bytes& tree_hash,
throw InvalidParameterError("Tree does not match GroupInfo");
}

if (!tree.parent_hash_valid()) {
throw InvalidParameterError("Invalid tree");
return tree;
}

bool
State::validate_tree() const
{
// The functionality here is somewhat duplicative of State::valid(const
// LeafNode&). Simply calling that method, however, would result in this
// method having quadratic scaling, since each call to valid() does a linear
// scan through the tree to check uniqueness of keys and compatibility of
// credential support.

// Validate that the tree is parent-hash valid
if (!_tree.parent_hash_valid()) {
return false;
}

return tree;
// Validate the signatures on all leaves
const auto signature_valid =
_tree.all_leaves([&](auto i, const auto& leaf_node) {
auto binding = std::optional<LeafNode::MemberBinding>{};
switch (leaf_node.source()) {
case LeafNodeSource::commit:
case LeafNodeSource::update:
binding = LeafNode::MemberBinding{ _group_id, i };
break;

default:
// Nothing to do
break;
}

return leaf_node.verify(_suite, binding);
});
if (!signature_valid) {
return false;
}

// Collect cross-tree properties
auto n_leaves = size_t(0);
auto encryption_keys = std::set<bytes>{};
auto signature_keys = std::set<bytes>{};
auto credential_types = std::set<CredentialType>{};
_tree.all_leaves([&](auto /* i */, const auto& leaf_node) {
n_leaves += 1;
encryption_keys.insert(leaf_node.encryption_key.data);
signature_keys.insert(leaf_node.signature_key.data);
credential_types.insert(leaf_node.credential.type());
return true;
});

// Verify uniqueness of keys
if (encryption_keys.size() != n_leaves) {
return false;
}

if (signature_keys.size() != n_leaves) {
return false;
}

// Verify that each leaf indicates support for all required parameters
return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) {
const auto supports_group_extensions =
leaf_node.verify_extension_support(_extensions);
const auto supports_own_extensions =
leaf_node.verify_extension_support(leaf_node.extensions);
const auto supports_group_credentials =
leaf_node.capabilities.credentials_supported(credential_types);
return supports_group_extensions && supports_own_extensions &&
supports_group_credentials;
});
}

State::State(SignaturePrivateKey sig_priv,
Expand All @@ -92,6 +158,10 @@ State::State(SignaturePrivateKey sig_priv,
, _index(0)
, _identity_priv(std::move(sig_priv))
{
if (!validate_tree()) {
throw InvalidParameterError("Invalid tree");
}

// The following are not set:
// _index
// _tree_priv
Expand Down Expand Up @@ -174,6 +244,11 @@ State::State(const HPKEPrivateKey& init_priv,

_extensions = group_info.group_context.extensions;

// Validate that the tree is in fact consistent with the group's parameters
if (!validate_tree()) {
throw InvalidParameterError("Invalid tree");
}

// Construct TreeKEM private key from parts provided
auto maybe_index = _tree.find(key_package.leaf_node);
if (!maybe_index) {
Expand Down Expand Up @@ -1130,19 +1205,9 @@ State::apply(const GroupContextExtensions& gce)
bool
State::extensions_supported(const ExtensionList& exts) const
{
for (LeafIndex i{ 0 }; i < _tree.size; i.val++) {
const auto& maybe_leaf = _tree.leaf_node(i);
if (!maybe_leaf) {
continue;
}

const auto& leaf = opt::get(maybe_leaf);
if (!leaf.verify_extension_support(exts)) {
return false;
}
}

return true;
return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) {
return leaf_node.verify_extension_support(exts);
});
}

void
Expand Down Expand Up @@ -1431,9 +1496,7 @@ State::valid(const LeafNode& leaf_node,
// the ID for each extension in the extensions field is listed in the
// capabilities.extensions field of the LeafNode.
auto supports_own_extensions =
stdx::all_of(leaf_node.extensions.extensions, [&](const auto& ext) {
return stdx::contains(leaf_node.capabilities.extensions, ext.type);
});
leaf_node.verify_extension_support(leaf_node.extensions);

return (signature_valid && supports_group_extensions && correct_source &&
mutual_credential_support && supports_own_extensions);
Expand Down Expand Up @@ -1536,19 +1599,7 @@ State::valid(const ExternalInit& external_init) const
bool
State::valid(const GroupContextExtensions& gce) const
{
// Verify that each extension is supported by all members
for (auto i = LeafIndex{ 0 }; i < _tree.size; i.val++) {
const auto maybe_leaf = _tree.leaf_node(i);
if (!maybe_leaf) {
continue;
}

const auto& leaf = opt::get(maybe_leaf);
if (!leaf.verify_extension_support(gce.group_context_extensions)) {
return false;
}
}
return true;
return extensions_supported(gce.group_context_extensions);
}

bool
Expand Down Expand Up @@ -2072,19 +2123,14 @@ State::group_info(bool inline_tree) const
std::vector<LeafNode>
State::roster() const
{
auto leaves = std::vector<LeafNode>(_tree.size.val);
auto leaf_count = uint32_t(0);
auto leaves = std::vector<LeafNode>{};
leaves.reserve(_tree.size.val);

for (uint32_t i = 0; i < _tree.size.val; i++) {
const auto& maybe_leaf = _tree.leaf_node(LeafIndex{ i });
if (!maybe_leaf) {
continue;
}
leaves.at(leaf_count) = opt::get(maybe_leaf);
leaf_count++;
}
_tree.all_leaves([&](auto /* i */, auto leaf) {
leaves.push_back(leaf);
return true;
});

leaves.resize(leaf_count);
return leaves;
}

Expand All @@ -2097,20 +2143,19 @@ State::epoch_authenticator() const
LeafIndex
State::leaf_for_roster_entry(RosterIndex index) const
{
auto non_blank_leaves = uint32_t(0);

for (auto i = LeafIndex{ 0 }; i < _tree.size; i.val++) {
const auto& maybe_leaf = _tree.leaf_node(i);
if (!maybe_leaf) {
continue;
}
if (non_blank_leaves == index.val) {
return i;
auto visited = RosterIndex{ 0 };
auto found = std::optional<LeafIndex>{};
_tree.all_leaves([&](auto i, const auto& /* leaf_node */) {
if (visited == index) {
found = i;
return false;
}
non_blank_leaves += 1;
}

throw InvalidParameterError("Invalid roster index");
visited.val += 1;
return true;
});

return opt::get(found);
}

State
Expand Down

0 comments on commit 661758e

Please sign in to comment.