Skip to content

Commit

Permalink
[WIP] Refactored policy_node_keyexpr_t to explicitly label which of t…
Browse files Browse the repository at this point in the history
…he union type is used; annotated parts of the code that are not generalized to musig key expressions
  • Loading branch information
bigspider committed Feb 23, 2024
1 parent 9847733 commit ef7dd97
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/common/wallet.c
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ static int parse_keyexpr(buffer_t *in_buf,
return WITH_ERROR(-1, "The key index in a placeholder must be at most 32767");
}

out->key_index = (int16_t) k;
out->k.key_index = (int16_t) k;
} else if (c == 'm') {
// parse a musig(key1,...,keyn) expression, where each key is a key expression
if (!consume_characters(in_buf, "usig(", 5)) {
Expand Down Expand Up @@ -531,7 +531,7 @@ static int parse_keyexpr(buffer_t *in_buf,
musig_info->n = n_musig_keys;
i_uint16(&musig_info->key_indexes, key_indexes);

i_musig_aggr_key_info(&out->musig_info, musig_info);
i_musig_aggr_key_info(&out->m.musig_info, musig_info);
} else {
return WITH_ERROR(-1, "Expected key placeholder starting with '@', or musig");
}
Expand Down
9 changes: 7 additions & 2 deletions src/common/wallet.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,9 +333,14 @@ typedef struct {
KeyExpressionType type;
union {
// type == 0
int16_t key_index; // index of the key (common between V1 and V2)
struct {
int16_t key_index; // index of the key (common between V1 and V2)
} k;

// type == 1
rptr_musig_aggr_key_info_t musig_info;
struct {
rptr_musig_aggr_key_info_t musig_info;
} m;
};
} policy_node_keyexpr_t;

Expand Down
65 changes: 43 additions & 22 deletions src/handler/lib/policy.c
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,12 @@ __attribute__((warn_unused_result)) static int get_derived_pubkey(

serialized_extended_pubkey_t ext_pubkey;

int ret = get_extended_pubkey(dispatcher_context, wdi, key_expr->key_index, &ext_pubkey);
if (key_expr->type != KEY_EXPRESSION_NORMAL) {
PRINTF("Not implemented\n"); // TODO
return -1;
}

int ret = get_extended_pubkey(dispatcher_context, wdi, key_expr->k.key_index, &ext_pubkey);
if (ret < 0) {
return -1;
}
Expand Down Expand Up @@ -1376,7 +1381,12 @@ static int get_bip44_purpose(const policy_node_t *descriptor_template) {
return -1;
}

if (kp->key_index != 0 || kp->num_first != 0 || kp->num_second != 1) {
if (kp->type != KEY_EXPRESSION_NORMAL) {
// any key expression that is not a play xpub is not BIP-44 compliant
return -1;
}

if (kp->k.key_index != 0 || kp->num_first != 0 || kp->num_second != 1) {
return -1;
}

Expand Down Expand Up @@ -1508,7 +1518,7 @@ bool check_wallet_hmac(const uint8_t wallet_id[static 32], const uint8_t wallet_
static int get_keyexpr_by_index_in_tree(const policy_node_tree_t *tree,
unsigned int i,
const policy_node_t **out_tapleaf_ptr,
policy_node_keyexpr_t *out_keyexpr) {
policy_node_keyexpr_t **out_keyexpr) {
if (tree->is_leaf) {
int ret = get_keyexpr_by_index(r_policy_node(&tree->script), i, NULL, out_keyexpr);
if (ret >= 0 && out_tapleaf_ptr != NULL && i < (unsigned) ret) {
Expand All @@ -1534,16 +1544,12 @@ static int get_keyexpr_by_index_in_tree(const policy_node_tree_t *tree,
}
}

// TODO: generalize for musig. Note that this is broken for musig, as out_keyexpr
// can't be filled in for musig key expressions (as it's dynamic and contains
// relative pointers). We should probably refactor to return the pointer to the
// key expression and removing the out_keyexpr argument.
int get_keyexpr_by_index(const policy_node_t *policy,
unsigned int i,
const policy_node_t **out_tapleaf_ptr,
policy_node_keyexpr_t *out_keyexpr) {
policy_node_keyexpr_t **out_keyexpr) {
// make sure that out_keyexpr is a valid pointer, if the output is not needed
policy_node_keyexpr_t tmp;
policy_node_keyexpr_t *tmp;
if (out_keyexpr == NULL) {
out_keyexpr = &tmp;
}
Expand All @@ -1568,16 +1574,14 @@ int get_keyexpr_by_index(const policy_node_t *policy,
case TOKEN_WPKH: {
if (i == 0) {
policy_node_with_key_t *wpkh = (policy_node_with_key_t *) policy;
memcpy(out_keyexpr,
r_policy_node_keyexpr(&wpkh->key),
sizeof(policy_node_keyexpr_t));
*out_keyexpr = r_policy_node_keyexpr(&wpkh->key);
}
return 1;
}
case TOKEN_TR: {
policy_node_tr_t *tr = (policy_node_tr_t *) policy;
if (i == 0) {
memcpy(out_keyexpr, r_policy_node_keyexpr(&tr->key), sizeof(policy_node_keyexpr_t));
*out_keyexpr = r_policy_node_keyexpr(&tr->key);
}
if (!isnull_policy_node_tree(&tr->tree)) {
int ret_tree = get_keyexpr_by_index_in_tree(
Expand All @@ -1604,7 +1608,7 @@ int get_keyexpr_by_index(const policy_node_t *policy,

if (i < (unsigned int) node->n) {
policy_node_keyexpr_t *key_expressions = r_policy_node_keyexpr(&node->keys);
memcpy(out_keyexpr, &key_expressions[i], sizeof(policy_node_keyexpr_t));
*out_keyexpr = &key_expressions[i];
}

return node->n;
Expand Down Expand Up @@ -1715,16 +1719,24 @@ int get_keyexpr_by_index(const policy_node_t *policy,
}

int count_distinct_keys_info(const policy_node_t *policy) {
policy_node_keyexpr_t key_expression;
policy_node_keyexpr_t *key_expression_ptr;
int ret = -1, cur, n_key_expressions;

for (cur = 0;
cur < (n_key_expressions = get_keyexpr_by_index(policy, cur, NULL, &key_expression));
cur < (n_key_expressions = get_keyexpr_by_index(policy, cur, NULL, &key_expression_ptr));
++cur) {
if (n_key_expressions < 0) {
return -1;
}
ret = MAX(ret, key_expression.key_index + 1);
if (key_expression_ptr->type == KEY_EXPRESSION_NORMAL) {
ret = MAX(ret, key_expression_ptr->k.key_index + 1);
} else if (key_expression_ptr->type == KEY_EXPRESSION_MUSIG) {
musig_aggr_key_info_t *musig_info =
r_musig_aggr_key_info(&key_expression_ptr->m.musig_info);
ret = MAX(ret, musig_info->n);
} else {
LEDGER_ASSERT(false, "Unknown key expression type");
}
}
return ret;
}
Expand Down Expand Up @@ -1912,21 +1924,30 @@ int is_policy_sane(dispatcher_context_t *dispatcher_context,
// proportional to the depth of the wallet policy's abstract syntax tree.
for (int i = 0; i < n_key_expressions - 1;
i++) { // no point in running this for the last key expression
policy_node_keyexpr_t kp_i;
policy_node_keyexpr_t *kp_i;
if (0 > get_keyexpr_by_index(policy, i, NULL, &kp_i)) {
return WITH_ERROR(-1, "Unexpected error retrieving key expressions from the policy");
}
for (int j = i + 1; j < n_key_expressions; j++) {
policy_node_keyexpr_t kp_j;
policy_node_keyexpr_t *kp_j;
if (0 > get_keyexpr_by_index(policy, j, NULL, &kp_j)) {
return WITH_ERROR(-1,
"Unexpected error retrieving key expressions from the policy");
}

if (kp_i->type != kp_j->type) {
// if one is a key and the other is a musig, there's nothing else to check
continue;
}

LEDGER_ASSERT(
kp_i->type == KEY_EXPRESSION_NORMAL && kp_j->type == KEY_EXPRESSION_NORMAL,
"TODO");

// key expressions for the same key must have disjoint derivation options
if (kp_i.key_index == kp_j.key_index) {
if (kp_i.num_first == kp_j.num_first || kp_i.num_first == kp_j.num_second ||
kp_i.num_second == kp_j.num_first || kp_i.num_second == kp_j.num_second) {
if (kp_i->k.key_index == kp_j->k.key_index) {
if (kp_i->num_first == kp_j->num_first || kp_i->num_first == kp_j->num_second ||
kp_i->num_second == kp_j->num_first || kp_i->num_second == kp_j->num_second) {
return WITH_ERROR(-1,
"Key expressions with repeated derivations in miniscript");
}
Expand Down
5 changes: 3 additions & 2 deletions src/handler/lib/policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,14 @@ bool check_wallet_hmac(const uint8_t wallet_id[static 32], const uint8_t wallet_
* If not NULL, and if the i-th key expression is in a tapleaf of the policy, receives the pointer
* to the tapleaf's script.
* @param[out] out_keyexpr
* If not NULL, it is a pointer that will receive the i-th key expression of the policy.
* If not NULL, it is a pointer that will receive a pointer to the i-th key expression of the
* policy.
* @return the number of key expressions in the policy on success; -1 in case of error.
*/
__attribute__((warn_unused_result)) int get_keyexpr_by_index(const policy_node_t *policy,
unsigned int i,
const policy_node_t **out_tapleaf_ptr,
policy_node_keyexpr_t *out_keyexpr);
policy_node_keyexpr_t **out_keyexpr);

/**
* Determines the expected number of unique keys in the provided policy's key information.
Expand Down
41 changes: 23 additions & 18 deletions src/handler/sign_psbt.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ typedef struct {
} output_info_t;

typedef struct {
policy_node_keyexpr_t key_expression;
policy_node_keyexpr_t *key_expression_ptr;
int cur_index;
uint32_t fingerprint;
uint8_t key_derivation_length;
Expand Down Expand Up @@ -451,10 +451,10 @@ static int read_change_and_index_from_psbt_bip32_derivation(
}

// check if the 'change' derivation step is indeed coherent with the key expression
if (change == keyexpr_info->key_expression.num_first) {
if (change == keyexpr_info->key_expression_ptr->num_first) {
in_out->is_change = false;
in_out->address_index = addr_index;
} else if (change == keyexpr_info->key_expression.num_second) {
} else if (change == keyexpr_info->key_expression_ptr->num_second) {
in_out->is_change = true;
in_out->address_index = addr_index;
} else {
Expand Down Expand Up @@ -710,12 +710,17 @@ static bool __attribute__((noinline)) fill_keyexpr_info_if_internal(dispatcher_c
policy_map_key_info_t key_info;
{
uint8_t key_info_str[MAX_POLICY_KEY_INFO_LEN];
int key_info_len = call_get_merkle_leaf_element(dc,
st->wallet_header_keys_info_merkle_root,
st->wallet_header_n_keys,
keyexpr_info->key_expression.key_index,
key_info_str,
sizeof(key_info_str));

// TODO: generalize for musig: keyexpr_info->key_expression->key_index is wrong
LEDGER_ASSERT(keyexpr_info->key_expression_ptr->type == KEY_EXPRESSION_NORMAL, "TODO");

int key_info_len =
call_get_merkle_leaf_element(dc,
st->wallet_header_keys_info_merkle_root,
st->wallet_header_n_keys,
keyexpr_info->key_expression_ptr->k.key_index,
key_info_str,
sizeof(key_info_str));

if (key_info_len < 0) {
SEND_SW(dc, SW_BAD_STATE); // should never happen
Expand Down Expand Up @@ -775,7 +780,7 @@ static bool find_first_internal_keyexpr(dispatcher_context_t *dc,
int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map,
keyexpr_info->cur_index,
NULL,
&keyexpr_info->key_expression);
&keyexpr_info->key_expression_ptr);
if (n_key_expressions < 0) {
SEND_SW(dc, SW_BAD_STATE); // should never happen
return false;
Expand Down Expand Up @@ -1884,9 +1889,9 @@ static bool __attribute__((noinline)) sign_sighash_ecdsa_and_yield(dispatcher_co
for (int i = 0; i < keyexpr_info->key_derivation_length; i++) {
sign_path[i] = keyexpr_info->key_derivation[i];
}
sign_path[keyexpr_info->key_derivation_length] = input->in_out.is_change
? keyexpr_info->key_expression.num_second
: keyexpr_info->key_expression.num_first;
sign_path[keyexpr_info->key_derivation_length] =
input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second
: keyexpr_info->key_expression_ptr->num_first;
sign_path[keyexpr_info->key_derivation_length + 1] = input->in_out.address_index;

int sign_path_len = keyexpr_info->key_derivation_length + 2;
Expand Down Expand Up @@ -1953,8 +1958,8 @@ static bool __attribute__((noinline)) sign_sighash_schnorr_and_yield(dispatcher_
sign_path[i] = keyexpr_info->key_derivation[i];
}
sign_path[keyexpr_info->key_derivation_length] =
input->in_out.is_change ? keyexpr_info->key_expression.num_second
: keyexpr_info->key_expression.num_first;
input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second
: keyexpr_info->key_expression_ptr->num_first;
sign_path[keyexpr_info->key_derivation_length + 1] = input->in_out.address_index;

int sign_path_len = keyexpr_info->key_derivation_length + 2;
Expand Down Expand Up @@ -2343,8 +2348,8 @@ static bool __attribute__((noinline)) fill_taproot_keyexpr_info(dispatcher_conte
const input_info_t *input,
const policy_node_t *tapleaf_ptr,
keyexpr_info_t *keyexpr_info) {
uint32_t change = input->in_out.is_change ? keyexpr_info->key_expression.num_second
: keyexpr_info->key_expression.num_first;
uint32_t change = input->in_out.is_change ? keyexpr_info->key_expression_ptr->num_second
: keyexpr_info->key_expression_ptr->num_first;
uint32_t address_index = input->in_out.address_index;

cx_sha256_t hash_context;
Expand Down Expand Up @@ -2413,7 +2418,7 @@ sign_transaction(dispatcher_context_t *dc,
int n_key_expressions = get_keyexpr_by_index(st->wallet_policy_map,
key_expression_index,
&tapleaf_ptr,
&keyexpr_info.key_expression);
&keyexpr_info.key_expression_ptr);

if (n_key_expressions < 0) {
SEND_SW(dc, SW_BAD_STATE); // should never happen
Expand Down

0 comments on commit ef7dd97

Please sign in to comment.