Skip to content

Commit

Permalink
Update to latest key schedule test vector (#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
bifurcation authored Feb 21, 2023
1 parent 1322340 commit 3482943
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 37 deletions.
9 changes: 8 additions & 1 deletion cmd/interop/src/json_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,15 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SecretTreeTestVector,
sender_data,
leaves)

NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(KeyScheduleTestVector::Export,
label,
context,
length,
secret)
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(KeyScheduleTestVector::Epoch,
tree_hash,
commit_secret,
psk_secret,
confirmed_transcript_hash,
group_context,
joiner_secret,
Expand All @@ -224,7 +230,8 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(KeyScheduleTestVector::Epoch,
confirmation_key,
membership_key,
resumption_psk,
external_pub)
external_pub,
exporter)
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(KeyScheduleTestVector,
cipher_suite,
group_id,
Expand Down
11 changes: 9 additions & 2 deletions cmd/interop/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,15 @@ make_test_vector(uint64_t type)
case TestVectorType::TREE_MATH:
return TreeMathTestVector{ n };

case TestVectorType::KEY_SCHEDULE:
return KeyScheduleTestVector{ suite, n };
case TestVectorType::KEY_SCHEDULE: {
auto cases = std::vector<KeyScheduleTestVector>();

for (const auto& suite : mls::all_supported_suites) {
cases.emplace_back(suite, n);
}

return cases;
}

case TestVectorType::TRANSCRIPT:
return TranscriptTestVector{ suite };
Expand Down
36 changes: 24 additions & 12 deletions include/mls/key_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ struct KeyScheduleEpoch

public:
bytes joiner_secret;
bytes psk_secret;
bytes epoch_secret;

bytes sender_data_secret;
Expand All @@ -115,10 +114,10 @@ struct KeyScheduleEpoch
KeyScheduleEpoch() = default;

// Full initializer, used by invited joiner
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context);
static KeyScheduleEpoch joiner(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context);

// Ciphersuite-only initializer, used by external joiner
KeyScheduleEpoch(CipherSuite suite_in);
Expand All @@ -128,13 +127,6 @@ struct KeyScheduleEpoch
const bytes& init_secret,
const bytes& context);

// Subsequent epochs
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context);

static std::tuple<bytes, bytes> external_init(
CipherSuite suite,
const HPKEPublicKey& external_pub);
Expand Down Expand Up @@ -162,6 +154,26 @@ struct KeyScheduleEpoch
static KeyAndNonce sender_data_keys(CipherSuite suite,
const bytes& sender_data_secret,
const bytes& ciphertext);

// TODO(RLB) make these methods private, but accessible to test vectors
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& commit_secret,
const bytes& psk_secret,
const bytes& context);
KeyScheduleEpoch next_raw(const bytes& commit_secret,
const bytes& psk_secret,
const std::optional<bytes>& force_init_secret,
const bytes& context) const;
static bytes welcome_secret_raw(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret);

private:
KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const bytes& psk_secret,
const bytes& context);
};

bool
Expand Down
8 changes: 5 additions & 3 deletions lib/mls_vectors/include/mls_vectors/mls_vectors.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,18 @@ struct KeyScheduleTestVector : PseudoRandom
{
struct Export
{
std::string exporter_label;
size_t exporter_length;
bytes exported;
std::string label;
bytes context;
size_t length;
bytes secret;
};

struct Epoch
{
// Chosen by the generator
bytes tree_hash;
bytes commit_secret;
bytes psk_secret;
bytes confirmed_transcript_hash;

// Computed values
Expand Down
33 changes: 20 additions & 13 deletions lib/mls_vectors/src/mls_vectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,19 +584,24 @@ KeyScheduleTestVector::KeyScheduleTestVector(CipherSuite suite,
epoch_prg.secret("confirmed_transcript_hash");
auto ctx = tls::marshal(group_context);

auto commit_secret = epoch_prg.secret("commit_secret");
// TODO(RLB) Add Test case for externally-driven epoch change
epoch = epoch.next(commit_secret, {}, std::nullopt, ctx);
auto commit_secret = epoch_prg.secret("commit_secret");
auto psk_secret = epoch_prg.secret("psk_secret");
epoch = epoch.next_raw(commit_secret, psk_secret, std::nullopt, ctx);

auto welcome_secret =
KeyScheduleEpoch::welcome_secret(cipher_suite, epoch.joiner_secret, {});
auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw(
cipher_suite, epoch.joiner_secret, psk_secret);

auto exporter_label = to_hex(epoch_prg.secret("exporter_label"));
auto exporter_prg = epoch_prg.sub("exporter");
auto exporter_label = to_hex(exporter_prg.secret("label"));
auto exporter_context = exporter_prg.secret("context");
auto exporter_length = cipher_suite.secret_size();
auto exported = epoch.do_export(exporter_label, {}, exporter_length);
auto exported =
epoch.do_export(exporter_label, exporter_context, exporter_length);

epochs.push_back({ group_context.tree_hash,
commit_secret,
psk_secret,
group_context.confirmed_transcript_hash,

ctx,
Expand All @@ -618,6 +623,7 @@ KeyScheduleTestVector::KeyScheduleTestVector(CipherSuite suite,

{
exporter_label,
exporter_context,
exporter_length,
exported,
} });
Expand All @@ -639,13 +645,14 @@ KeyScheduleTestVector::verify() const
auto ctx = tls::marshal(group_context);
VERIFY_EQUAL("group context", ctx, tve.group_context);

epoch = epoch.next(tve.commit_secret, {}, std::nullopt, ctx);
epoch =
epoch.next_raw(tve.commit_secret, tve.psk_secret, std::nullopt, ctx);

// Verify the rest of the epoch
VERIFY_EQUAL("joiner secret", epoch.joiner_secret, tve.joiner_secret);

auto welcome_secret =
KeyScheduleEpoch::welcome_secret(cipher_suite, tve.joiner_secret, {});
auto welcome_secret = KeyScheduleEpoch::welcome_secret_raw(
cipher_suite, tve.joiner_secret, tve.psk_secret);
VERIFY_EQUAL("welcome secret", welcome_secret, tve.welcome_secret);

VERIFY_EQUAL(
Expand All @@ -667,8 +674,8 @@ KeyScheduleTestVector::verify() const
"external pub", epoch.external_priv.public_key, tve.external_pub);

auto exported = epoch.do_export(
tve.exporter.exporter_label, {}, tve.exporter.exporter_length);
VERIFY_EQUAL("exported", exported, tve.exporter.exported);
tve.exporter.label, tve.exporter.context, tve.exporter.length);
VERIFY_EQUAL("exported", exported, tve.exporter.secret);

group_context.epoch += 1;
}
Expand Down Expand Up @@ -1063,7 +1070,7 @@ WelcomeTestVector::WelcomeTestVector(CipherSuite suite)
cipher_suite, group_id, epoch, tree_hash, confirmed_transcript_hash, {}
};

auto key_schedule = KeyScheduleEpoch(
auto key_schedule = KeyScheduleEpoch::joiner(
cipher_suite, joiner_secret, {}, tls::marshal(group_context));
auto confirmation_tag =
key_schedule.confirmation_tag(confirmed_transcript_hash);
Expand Down Expand Up @@ -1106,7 +1113,7 @@ WelcomeTestVector::verify() const

// Verify confirmation tag
const auto& group_context = group_info.group_context;
auto key_schedule = KeyScheduleEpoch(
auto key_schedule = KeyScheduleEpoch::joiner(
cipher_suite, group_secrets.joiner_secret, {}, tls::marshal(group_context));
auto confirmation_tag =
key_schedule.confirmation_tag(group_context.confirmed_transcript_hash);
Expand Down
36 changes: 31 additions & 5 deletions src/key_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,21 @@ make_epoch_secret(CipherSuite suite,
member_secret, "epoch", context, suite.secret_size());
}

KeyScheduleEpoch
KeyScheduleEpoch::joiner(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& context)
{
return { suite_in, joiner_secret, make_psk_secret(suite_in, psks), context };
}

KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& joiner_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& psk_secret,
const bytes& context)
: suite(suite_in)
, joiner_secret(joiner_secret)
, psk_secret(make_psk_secret(suite_in, psks))
, epoch_secret(
make_epoch_secret(suite_in, joiner_secret, psk_secret, context))
, sender_data_secret(suite.derive_secret(epoch_secret, "sender data"))
Expand Down Expand Up @@ -337,12 +345,12 @@ KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
KeyScheduleEpoch::KeyScheduleEpoch(CipherSuite suite_in,
const bytes& init_secret,
const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const bytes& psk_secret,
const bytes& context)
: KeyScheduleEpoch(
suite_in,
make_joiner_secret(suite_in, context, init_secret, commit_secret),
psks,
psk_secret,
context)
{
}
Expand All @@ -369,13 +377,23 @@ KeyScheduleEpoch::next(const bytes& commit_secret,
const std::vector<PSKWithSecret>& psks,
const std::optional<bytes>& force_init_secret,
const bytes& context) const
{
return next_raw(
commit_secret, make_psk_secret(suite, psks), force_init_secret, context);
}

KeyScheduleEpoch
KeyScheduleEpoch::next_raw(const bytes& commit_secret,
const bytes& psk_secret,
const std::optional<bytes>& force_init_secret,
const bytes& context) const
{
auto actual_init_secret = init_secret;
if (force_init_secret) {
actual_init_secret = opt::get(force_init_secret);
}

return { suite, actual_init_secret, commit_secret, psks, context };
return { suite, actual_init_secret, commit_secret, psk_secret, context };
}

GroupKeySource
Expand Down Expand Up @@ -434,6 +452,14 @@ KeyScheduleEpoch::welcome_secret(CipherSuite suite,
const std::vector<PSKWithSecret>& psks)
{
auto psk_secret = make_psk_secret(suite, psks);
return welcome_secret_raw(suite, joiner_secret, psk_secret);
}

bytes
KeyScheduleEpoch::welcome_secret_raw(CipherSuite suite,
const bytes& joiner_secret,
const bytes& psk_secret)
{
auto extract = suite.hpke().kdf.extract(joiner_secret, psk_secret);
return suite.derive_secret(extract, "welcome");
}
Expand Down
2 changes: 1 addition & 1 deletion src/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ State::State(const HPKEPrivateKey& init_priv,

// Ratchet forward into the current epoch
auto group_ctx = tls::marshal(group_context());
_key_schedule = KeyScheduleEpoch(
_key_schedule = KeyScheduleEpoch::joiner(
_suite, secrets.joiner_secret, { /* no PSKs */ }, group_ctx);
_keys = _key_schedule.encryption_keys(_tree.size);

Expand Down

0 comments on commit 3482943

Please sign in to comment.