Skip to content

Commit

Permalink
🚸 zb,zm: Use std::num::NonZeroU32 for serial numbers in pub API
Browse files Browse the repository at this point in the history
This is to ensure that the serial number is never 0.

This breaks API in multiple places but typical user will not be affected
by this.
  • Loading branch information
zeenix committed Aug 5, 2023
1 parent 22d9ced commit de653f2
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 44 deletions.
15 changes: 10 additions & 5 deletions zbus/src/blocking/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use enumflags2::BitFlags;
use event_listener::EventListener;
use static_assertions::assert_impl_all;
use std::{io, ops::Deref, sync::Arc};
use std::{io, num::NonZeroU32, ops::Deref, sync::Arc};
use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, OwnedUniqueName, WellKnownName};
use zvariant::ObjectPath;

Expand Down Expand Up @@ -69,7 +69,7 @@ impl Connection {
/// The connection sets a unique serial number on the message before sending it off.
///
/// On successfully sending off `msg`, the assigned serial number is returned.
pub fn send_message(&self, msg: Message) -> Result<u32> {
pub fn send_message(&self, msg: Message) -> Result<NonZeroU32> {
block_on(self.inner.send_message(msg))
}

Expand Down Expand Up @@ -143,7 +143,7 @@ impl Connection {
/// given `body`.
///
/// Returns the message serial number.
pub fn reply<B>(&self, call: &Message, body: &B) -> Result<u32>
pub fn reply<B>(&self, call: &Message, body: &B) -> Result<NonZeroU32>
where
B: serde::ser::Serialize + zvariant::DynamicType,
{
Expand All @@ -156,7 +156,12 @@ impl Connection {
/// with the given `error_name` and `body`.
///
/// Returns the message serial number.
pub fn reply_error<'e, E, B>(&self, call: &Message, error_name: E, body: &B) -> Result<u32>
pub fn reply_error<'e, E, B>(
&self,
call: &Message,
error_name: E,
body: &B,
) -> Result<NonZeroU32>
where
B: serde::ser::Serialize + zvariant::DynamicType,
E: TryInto<ErrorName<'e>>,
Expand All @@ -175,7 +180,7 @@ impl Connection {
&self,
call: &zbus::message::Header<'_>,
err: impl DBusError,
) -> Result<u32> {
) -> Result<NonZeroU32> {
block_on(self.inner.reply_dbus_error(call, err))
}

Expand Down
18 changes: 11 additions & 7 deletions zbus/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use static_assertions::assert_impl_all;
use std::{
collections::HashMap,
io::{self, ErrorKind},
num::NonZeroU32,
ops::Deref,
pin::Pin,
sync::{
Expand Down Expand Up @@ -226,7 +227,7 @@ assert_impl_all!(Connection: Send, Sync, Unpin);
#[derive(Debug)]
pub(crate) struct PendingMethodCall {
stream: Option<MessageStream>,
serial: u32,
serial: NonZeroU32,
}

impl Future for PendingMethodCall {
Expand Down Expand Up @@ -299,7 +300,7 @@ impl Connection {
/// number on the message before sending it off, for you.
///
/// On successfully sending off `msg`, the assigned serial number is returned.
pub async fn send_message(&self, mut msg: Message) -> Result<u32> {
pub async fn send_message(&self, mut msg: Message) -> Result<NonZeroU32> {
let serial = self.assign_serial_num(&mut msg)?;

trace!("Sending message: {:?}", msg);
Expand Down Expand Up @@ -447,7 +448,7 @@ impl Connection {
/// given `body`.
///
/// Returns the message serial number.
pub async fn reply<B>(&self, call: &Message, body: &B) -> Result<u32>
pub async fn reply<B>(&self, call: &Message, body: &B) -> Result<NonZeroU32>
where
B: serde::ser::Serialize + zvariant::DynamicType,
{
Expand All @@ -466,7 +467,7 @@ impl Connection {
call: &Message,
error_name: E,
body: &B,
) -> Result<u32>
) -> Result<NonZeroU32>
where
B: serde::ser::Serialize + zvariant::DynamicType,
E: TryInto<ErrorName<'e>>,
Expand All @@ -486,7 +487,7 @@ impl Connection {
&self,
call: &zbus::message::Header<'_>,
err: impl DBusError,
) -> Result<u32> {
) -> Result<NonZeroU32> {
let m = err.create_reply(call);
self.send_message(m?).await
}
Expand Down Expand Up @@ -793,8 +794,11 @@ impl Connection {
/// Assigns a serial number to `msg` that is unique to this connection.
///
/// This method can fail if `msg` is corrupted.
pub fn assign_serial_num(&self, msg: &mut Message) -> Result<u32> {
let serial = self.next_serial();
pub fn assign_serial_num(&self, msg: &mut Message) -> Result<NonZeroU32> {
let serial = self
.next_serial()
.try_into()
.map_err(|_| Error::InvalidSerial)?;
msg.set_serial_num(serial)?;

Ok(serial)
Expand Down
2 changes: 1 addition & 1 deletion zbus/src/fdo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ mod tests {
#[test]
fn error_from_zerror() {
let mut m = Message::method(Some(":1.2"), None::<()>, "/", None::<()>, "foo", &()).unwrap();
m.set_serial_num(1).unwrap();
m.set_serial_num(1.try_into().unwrap()).unwrap();
let m = Message::method_error(
None::<()>,
&m,
Expand Down
6 changes: 3 additions & 3 deletions zbus/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,13 +258,13 @@ mod tests {
assert_eq!(m.member().unwrap(), "GetMachineId");
m.modify_primary_header(|primary| {
primary.set_flags(BitFlags::from(Flags::NoAutoStart));
primary.set_serial_num(11);
primary.set_serial_num(11.try_into().unwrap());

Ok(())
})
.unwrap();
let primary = m.primary_header();
assert!(*primary.serial_num().unwrap() == 11);
assert!(primary.serial_num().unwrap().get() == 11);
assert!(primary.flags() == Flags::NoAutoStart);
}

Expand Down Expand Up @@ -617,7 +617,7 @@ mod tests {
for m in stream {
let msg = m.unwrap();

if *msg.primary_header().serial_num().unwrap() == serial {
if msg.primary_header().serial_num().unwrap() == serial {
break;
}
}
Expand Down
4 changes: 1 addition & 3 deletions zbus/src/message/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,7 @@ impl<'a> Builder<'a> {

fn reply_to(mut self, reply_to: &Header<'_>) -> Result<Self> {
let serial = reply_to.primary().serial_num().ok_or(Error::MissingField)?;
self.header
.fields_mut()
.replace(Field::ReplySerial(*serial));
self.header.fields_mut().replace(Field::ReplySerial(serial));

if let Some(sender) = reply_to.sender()? {
self.destination(sender.to_owned())
Expand Down
11 changes: 8 additions & 3 deletions zbus/src/message/field.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use serde::{
de::{Deserialize, Deserializer, Error},
ser::{Serialize, Serializer},
Expand Down Expand Up @@ -103,7 +105,7 @@ pub enum Field<'f> {
/// The name of the error that occurred, for errors
ErrorName(ErrorName<'f>),
/// The serial number of the message this message is a reply to.
ReplySerial(u32),
ReplySerial(NonZeroU32),
/// The name of the connection this message is intended for.
Destination(BusName<'f>),
/// Unique name of the sending connection.
Expand Down Expand Up @@ -132,7 +134,7 @@ impl<'f> Serialize for Field<'f> {
Field::Interface(value) => (FieldCode::Interface, value.as_str().into()),
Field::Member(value) => (FieldCode::Member, value.as_str().into()),
Field::ErrorName(value) => (FieldCode::ErrorName, value.as_str().into()),
Field::ReplySerial(value) => (FieldCode::ReplySerial, (*value).into()),
Field::ReplySerial(value) => (FieldCode::ReplySerial, value.get().into()),
Field::Destination(value) => (FieldCode::Destination, value.as_str().into()),
Field::Sender(value) => (FieldCode::Sender, value.as_str().into()),
Field::Signature(value) => (FieldCode::Signature, value.as_ref().into()),
Expand Down Expand Up @@ -165,7 +167,10 @@ impl<'de: 'f, 'f> Deserialize<'de> for Field<'f> {
.map_err(D::Error::custom)?,
),
FieldCode::ReplySerial => {
Field::ReplySerial(u32::try_from(value).map_err(D::Error::custom)?)
let value = u32::try_from(value)
.map_err(D::Error::custom)
.and_then(|v| v.try_into().map_err(D::Error::custom))?;
Field::ReplySerial(value)
}
FieldCode::Destination => Field::Destination(
BusName::try_from(value)
Expand Down
13 changes: 7 additions & 6 deletions zbus/src/message/fields.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use serde::{Deserialize, Serialize};
use static_assertions::assert_impl_all;
use std::num::NonZeroU32;
use zbus_names::{InterfaceName, MemberName};
use zvariant::{ObjectPath, Type};

Expand Down Expand Up @@ -154,7 +155,7 @@ pub(crate) struct QuickFields {
path: FieldPos,
interface: FieldPos,
member: FieldPos,
reply_serial: Option<u32>,
reply_serial: Option<NonZeroU32>,
}

impl QuickFields {
Expand All @@ -179,7 +180,7 @@ impl QuickFields {
self.member.read(msg.as_bytes())
}

pub fn reply_serial(&self) -> Option<u32> {
pub fn reply_serial(&self) -> Option<NonZeroU32> {
self.reply_serial
}
}
Expand All @@ -206,16 +207,16 @@ mod tests {
fn test() {
let mut mf = Fields::new();
assert_eq!(mf.len(), 0);
mf.add(Field::ReplySerial(42));
mf.add(Field::ReplySerial(42.try_into().unwrap()));
assert_eq!(mf.len(), 1);
mf.add(Field::ReplySerial(43));
mf.add(Field::ReplySerial(43.try_into().unwrap()));
assert_eq!(mf.len(), 2);

let mut mf = Fields::new();
assert_eq!(mf.len(), 0);
mf.replace(Field::ReplySerial(42));
mf.replace(Field::ReplySerial(42.try_into().unwrap()));
assert_eq!(mf.len(), 1);
mf.replace(Field::ReplySerial(43));
mf.replace(Field::ReplySerial(43.try_into().unwrap()));
assert_eq!(mf.len(), 1);
}
}
25 changes: 14 additions & 11 deletions zbus/src/message/header.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::num::NonZeroU32;

use enumflags2::{bitflags, BitFlags};
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
Expand Down Expand Up @@ -201,15 +203,12 @@ impl PrimaryHeader {
///
/// **Note:** There is no setter provided for this in the public API since this is set by the
/// [`Connection`](struct.Connection.html) the message is sent over.
pub fn serial_num(&self) -> Option<&u32> {
match &self.serial_num {
0 => None,
serial_num => Some(serial_num),
}
pub fn serial_num(&self) -> Option<NonZeroU32> {
NonZeroU32::new(self.serial_num)
}

pub(crate) fn set_serial_num(&mut self, serial_num: u32) {
self.serial_num = serial_num;
pub(crate) fn set_serial_num(&mut self, serial_num: NonZeroU32) {
self.serial_num = serial_num.get();
}
}

Expand Down Expand Up @@ -310,8 +309,12 @@ impl<'m> Header<'m> {
}

/// The serial number of the message this message is a reply to.
pub fn reply_serial(&self) -> Result<Option<u32>, Error> {
get_field_u32!(self, ReplySerial)
pub fn reply_serial(&self) -> Result<Option<NonZeroU32>, Error> {
match self.fields().get_field(FieldCode::ReplySerial) {
Some(Field::ReplySerial(value)) => Ok(Some(*value)),
Some(_) => Err(Error::InvalidField),
None => Ok(None),
}
}

/// The name of the connection this message is intended for.
Expand Down Expand Up @@ -370,7 +373,7 @@ mod tests {
let mut f = Fields::new();
f.add(Field::ErrorName("org.zbus.Error".try_into()?));
f.add(Field::Destination(":1.11".try_into()?));
f.add(Field::ReplySerial(88));
f.add(Field::ReplySerial(88.try_into()?));
f.add(Field::Signature(Signature::from_str_unchecked("say")));
f.add(Field::UnixFDs(12));
let h = Header::new(PrimaryHeader::new(Type::MethodReturn, 77), f);
Expand All @@ -381,7 +384,7 @@ mod tests {
assert_eq!(h.member()?, None);
assert_eq!(h.error_name()?.unwrap(), "org.zbus.Error");
assert_eq!(h.destination()?.unwrap(), ":1.11");
assert_eq!(h.reply_serial()?, Some(88));
assert_eq!(h.reply_serial()?.map(Into::into), Some(88));
assert_eq!(h.sender()?, None);
assert_eq!(h.signature()?, Some(&Signature::from_str_unchecked("say")));
assert_eq!(h.unix_fds()?, Some(12));
Expand Down
12 changes: 8 additions & 4 deletions zbus/src/message/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
//! D-Bus Message.
use std::{fmt, io::Cursor};
use std::{
fmt,
io::Cursor,
num::NonZeroU32,
};

#[cfg(unix)]
use std::{
Expand Down Expand Up @@ -355,7 +359,7 @@ impl Message {
}

/// The serial number of the message this message is a reply to.
pub fn reply_serial(&self) -> Option<u32> {
pub fn reply_serial(&self) -> Option<NonZeroU32> {
self.quick_fields.reply_serial()
}

Expand Down Expand Up @@ -475,7 +479,7 @@ impl Message {
self.recv_seq
}

pub(crate) fn set_serial_num(&mut self, serial_num: u32) -> Result<()> {
pub(crate) fn set_serial_num(&mut self, serial_num: NonZeroU32) -> Result<()> {
self.modify_primary_header(|primary| {
primary.set_serial_num(serial_num);
Ok(())
Expand Down Expand Up @@ -607,7 +611,7 @@ mod tests {
),
)
.unwrap();
m.set_serial_num(1).unwrap();
m.set_serial_num(1.try_into().unwrap()).unwrap();
assert_eq!(
m.body_signature().unwrap().to_string(),
if cfg!(unix) { "hs" } else { "s" }
Expand Down
2 changes: 1 addition & 1 deletion zbus_macros/src/iface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ pub fn expand(args: AttributeArgs, mut input: ItemImpl) -> syn::Result<TokenStre
#reply
};
#zbus::object_server::DispatchResult::Async(::std::boxed::Box::pin(async move {
future.await.map(|_seq: u32| ())
future.await.map(|_seq: ::std::num::NonZeroU32| ())
}))
},
};
Expand Down

0 comments on commit de653f2

Please sign in to comment.