Skip to content

Commit

Permalink
Switch from const generic to associated const for verify key len.
Browse files Browse the repository at this point in the history
This saves consumers of the Aggregator trait from needing to specify an
L of their own.
  • Loading branch information
branlwyd committed Oct 14, 2022
1 parent 7b8a9bc commit 7b5ae53
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 11 deletions.
11 changes: 7 additions & 4 deletions src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,13 @@ where
}

/// The Aggregator's role in the execution of a VDAF.
pub trait Aggregator<const L: usize>: Vdaf
pub trait Aggregator: Vdaf
where
for<'a> &'a Self::AggregateShare: Into<Vec<u8>>,
{
/// The length of a verification key, in bytes.
const VERIFY_KEY_LEN: usize;

/// State of the Aggregator during the Prepare process.
type PrepareState: Clone + Debug;

Expand All @@ -194,7 +197,7 @@ where
/// message.
fn prepare_init(
&self,
verify_key: &[u8; L],
verify_key: &[u8; Self::VERIFY_KEY_LEN],
agg_id: usize,
agg_param: &Self::AggregationParam,
nonce: &[u8],
Expand All @@ -219,7 +222,7 @@ where
&self,
state: Self::PrepareState,
input: Self::PrepareMessage,
) -> Result<PrepareTransition<Self, L>, VdafError>;
) -> Result<PrepareTransition<Self>, VdafError>;

/// Aggregates a sequence of output shares into an aggregate share.
fn aggregate<M: IntoIterator<Item = Self::OutputShare>>(
Expand All @@ -245,7 +248,7 @@ where

/// A state transition of an Aggregator during the Prepare process.
#[derive(Debug)]
pub enum PrepareTransition<V: Aggregator<L>, const L: usize>
pub enum PrepareTransition<V: Aggregator>
where
for<'a> &'a V::AggregateShare: Into<Vec<u8>>,
{
Expand Down
6 changes: 4 additions & 2 deletions src/vdaf/poplar1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,11 +438,13 @@ fn get_level(agg_param: &BTreeSet<IdpfInput>) -> Result<usize, VdafError> {
}
}

impl<I, P, const L: usize> Aggregator<L> for Poplar1<I, P, L>
impl<I, P, const L: usize> Aggregator for Poplar1<I, P, L>
where
I: Idpf<2, 2>,
P: Prg<L>,
{
const VERIFY_KEY_LEN: usize = L;

type PrepareState = Poplar1PrepareState<I::Field>;
type PrepareShare = Poplar1PrepareMessage<I::Field>;
type PrepareMessage = Poplar1PrepareMessage<I::Field>;
Expand Down Expand Up @@ -571,7 +573,7 @@ where
&self,
mut state: Poplar1PrepareState<I::Field>,
msg: Poplar1PrepareMessage<I::Field>,
) -> Result<PrepareTransition<Self, L>, VdafError> {
) -> Result<PrepareTransition<Self>, VdafError> {
match &state.sketch {
SketchState::RoundOne => {
if msg.0.len() != 3 {
Expand Down
8 changes: 5 additions & 3 deletions src/vdaf/prio2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,16 @@ impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare {
}
}

impl Aggregator<32> for Prio2 {
impl Aggregator for Prio2 {
const VERIFY_KEY_LEN: usize = 32;

type PrepareState = Prio2PrepareState;
type PrepareShare = Prio2PrepareShare;
type PrepareMessage = ();

fn prepare_init(
&self,
agg_key: &[u8; 32],
agg_key: &[u8; Self::VERIFY_KEY_LEN],
agg_id: usize,
_agg_param: &(),
nonce: &[u8],
Expand Down Expand Up @@ -239,7 +241,7 @@ impl Aggregator<32> for Prio2 {
&self,
state: Prio2PrepareState,
_input: (),
) -> Result<PrepareTransition<Self, 32>, VdafError> {
) -> Result<PrepareTransition<Self>, VdafError> {
let data = match state.0 {
Share::Leader(data) => data,
Share::Helper(seed) => {
Expand Down
6 changes: 4 additions & 2 deletions src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,11 +764,13 @@ where
}
}

impl<T, P, const L: usize> Aggregator<L> for Prio3<T, P, L>
impl<T, P, const L: usize> Aggregator for Prio3<T, P, L>
where
T: Type,
P: Prg<L>,
{
const VERIFY_KEY_LEN: usize = L;

type PrepareState = Prio3PrepareState<T::Field, L>;
type PrepareShare = Prio3PrepareShare<T::Field, L>;
type PrepareMessage = Prio3PrepareMessage<L>;
Expand Down Expand Up @@ -943,7 +945,7 @@ where
&self,
step: Prio3PrepareState<T::Field, L>,
msg: Prio3PrepareMessage<L>,
) -> Result<PrepareTransition<Self, L>, VdafError> {
) -> Result<PrepareTransition<Self>, VdafError> {
if self.typ.joint_rand_len() > 0 {
// Check that the joint randomness was correct.
if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() {
Expand Down

0 comments on commit 7b5ae53

Please sign in to comment.