From 661758e111646958ca6538b85f5c6a8664a3b044 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Sat, 9 Sep 2023 15:57:54 -0400 Subject: [PATCH] Validate the tree on joining (#363) * Properly handle one-member trees without the one member in leaf 0 * Move uniqueness checking to tree operations * Rename WireFormat enum values to match RFC * clang-format * clang-tidy * Validate the tree on joining * Make tree validation linear-time * clang-format * Minor comment tweak --------- Co-authored-by: Richard Barnes --- include/mls/core_types.h | 8 ++ include/mls/state.h | 1 + src/state.cpp | 155 +++++++++++++++++++++++++-------------- 3 files changed, 109 insertions(+), 55 deletions(-) diff --git a/include/mls/core_types.h b/include/mls/core_types.h index cc895fa0..533c6a8c 100644 --- a/include/mls/core_types.h +++ b/include/mls/core_types.h @@ -113,6 +113,14 @@ struct Capabilities bool proposals_supported(const std::vector& required) const; bool credential_supported(const Credential& credential) const; + template + bool credentials_supported(const Container& required) const + { + return stdx::all_of(required, [&](CredentialType type) { + return stdx::contains(credentials, type); + }); + } + TLS_SERIALIZABLE(versions, cipher_suites, extensions, proposals, credentials) }; diff --git a/include/mls/state.h b/include/mls/state.h index b1ef0d93..329aadb0 100644 --- a/include/mls/state.h +++ b/include/mls/state.h @@ -277,6 +277,7 @@ class State TreeKEMPublicKey import_tree(const bytes& tree_hash, const std::optional& external, const ExtensionList& extensions); + bool validate_tree() const; // Form a commit, covering all the cases with slightly different validation // rules: diff --git a/src/state.cpp b/src/state.cpp index 9a01eea1..2dd02bfe 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -68,11 +68,77 @@ State::import_tree(const bytes& tree_hash, throw InvalidParameterError("Tree does not match GroupInfo"); } - if (!tree.parent_hash_valid()) { - throw InvalidParameterError("Invalid tree"); + return tree; +} + +bool +State::validate_tree() const +{ + // The functionality here is somewhat duplicative of State::valid(const + // LeafNode&). Simply calling that method, however, would result in this + // method having quadratic scaling, since each call to valid() does a linear + // scan through the tree to check uniqueness of keys and compatibility of + // credential support. + + // Validate that the tree is parent-hash valid + if (!_tree.parent_hash_valid()) { + return false; } - return tree; + // Validate the signatures on all leaves + const auto signature_valid = + _tree.all_leaves([&](auto i, const auto& leaf_node) { + auto binding = std::optional{}; + switch (leaf_node.source()) { + case LeafNodeSource::commit: + case LeafNodeSource::update: + binding = LeafNode::MemberBinding{ _group_id, i }; + break; + + default: + // Nothing to do + break; + } + + return leaf_node.verify(_suite, binding); + }); + if (!signature_valid) { + return false; + } + + // Collect cross-tree properties + auto n_leaves = size_t(0); + auto encryption_keys = std::set{}; + auto signature_keys = std::set{}; + auto credential_types = std::set{}; + _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + n_leaves += 1; + encryption_keys.insert(leaf_node.encryption_key.data); + signature_keys.insert(leaf_node.signature_key.data); + credential_types.insert(leaf_node.credential.type()); + return true; + }); + + // Verify uniqueness of keys + if (encryption_keys.size() != n_leaves) { + return false; + } + + if (signature_keys.size() != n_leaves) { + return false; + } + + // Verify that each leaf indicates support for all required parameters + return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + const auto supports_group_extensions = + leaf_node.verify_extension_support(_extensions); + const auto supports_own_extensions = + leaf_node.verify_extension_support(leaf_node.extensions); + const auto supports_group_credentials = + leaf_node.capabilities.credentials_supported(credential_types); + return supports_group_extensions && supports_own_extensions && + supports_group_credentials; + }); } State::State(SignaturePrivateKey sig_priv, @@ -92,6 +158,10 @@ State::State(SignaturePrivateKey sig_priv, , _index(0) , _identity_priv(std::move(sig_priv)) { + if (!validate_tree()) { + throw InvalidParameterError("Invalid tree"); + } + // The following are not set: // _index // _tree_priv @@ -174,6 +244,11 @@ State::State(const HPKEPrivateKey& init_priv, _extensions = group_info.group_context.extensions; + // Validate that the tree is in fact consistent with the group's parameters + if (!validate_tree()) { + throw InvalidParameterError("Invalid tree"); + } + // Construct TreeKEM private key from parts provided auto maybe_index = _tree.find(key_package.leaf_node); if (!maybe_index) { @@ -1130,19 +1205,9 @@ State::apply(const GroupContextExtensions& gce) bool State::extensions_supported(const ExtensionList& exts) const { - for (LeafIndex i{ 0 }; i < _tree.size; i.val++) { - const auto& maybe_leaf = _tree.leaf_node(i); - if (!maybe_leaf) { - continue; - } - - const auto& leaf = opt::get(maybe_leaf); - if (!leaf.verify_extension_support(exts)) { - return false; - } - } - - return true; + return _tree.all_leaves([&](auto /* i */, const auto& leaf_node) { + return leaf_node.verify_extension_support(exts); + }); } void @@ -1431,9 +1496,7 @@ State::valid(const LeafNode& leaf_node, // the ID for each extension in the extensions field is listed in the // capabilities.extensions field of the LeafNode. auto supports_own_extensions = - stdx::all_of(leaf_node.extensions.extensions, [&](const auto& ext) { - return stdx::contains(leaf_node.capabilities.extensions, ext.type); - }); + leaf_node.verify_extension_support(leaf_node.extensions); return (signature_valid && supports_group_extensions && correct_source && mutual_credential_support && supports_own_extensions); @@ -1536,19 +1599,7 @@ State::valid(const ExternalInit& external_init) const bool State::valid(const GroupContextExtensions& gce) const { - // Verify that each extension is supported by all members - for (auto i = LeafIndex{ 0 }; i < _tree.size; i.val++) { - const auto maybe_leaf = _tree.leaf_node(i); - if (!maybe_leaf) { - continue; - } - - const auto& leaf = opt::get(maybe_leaf); - if (!leaf.verify_extension_support(gce.group_context_extensions)) { - return false; - } - } - return true; + return extensions_supported(gce.group_context_extensions); } bool @@ -2072,19 +2123,14 @@ State::group_info(bool inline_tree) const std::vector State::roster() const { - auto leaves = std::vector(_tree.size.val); - auto leaf_count = uint32_t(0); + auto leaves = std::vector{}; + leaves.reserve(_tree.size.val); - for (uint32_t i = 0; i < _tree.size.val; i++) { - const auto& maybe_leaf = _tree.leaf_node(LeafIndex{ i }); - if (!maybe_leaf) { - continue; - } - leaves.at(leaf_count) = opt::get(maybe_leaf); - leaf_count++; - } + _tree.all_leaves([&](auto /* i */, auto leaf) { + leaves.push_back(leaf); + return true; + }); - leaves.resize(leaf_count); return leaves; } @@ -2097,20 +2143,19 @@ State::epoch_authenticator() const LeafIndex State::leaf_for_roster_entry(RosterIndex index) const { - auto non_blank_leaves = uint32_t(0); - - for (auto i = LeafIndex{ 0 }; i < _tree.size; i.val++) { - const auto& maybe_leaf = _tree.leaf_node(i); - if (!maybe_leaf) { - continue; - } - if (non_blank_leaves == index.val) { - return i; + auto visited = RosterIndex{ 0 }; + auto found = std::optional{}; + _tree.all_leaves([&](auto i, const auto& /* leaf_node */) { + if (visited == index) { + found = i; + return false; } - non_blank_leaves += 1; - } - throw InvalidParameterError("Invalid roster index"); + visited.val += 1; + return true; + }); + + return opt::get(found); } State