Skip to content

Fixes UB of recvmmsg and simplifies signatures of recvmmsg and sendmmsg #2120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ This project adheres to [Semantic Versioning](https://semver.org/).
- Fix `SigSet` incorrect implementation of `Eq`, `PartialEq` and `Hash`
([#1946](https://github.com/nix-rust/nix/pull/1946))

- Fixed the function signature of `recvmmsg`, potentially causing UB
([#2119](https://github.com/nix-rust/nix/issues/2119))

### Changed

- The following APIs now take an implementation of `AsFd` rather than a
Expand All @@ -33,6 +36,8 @@ This project adheres to [Semantic Versioning](https://semver.org/).
relaxed lifetime requirements relative to 0.27.1.
([#2136](https://github.com/nix-rust/nix/pull/2136))

- Simplified the function signatures of `recvmmsg` and `sendmmsg`

## [0.27.1] - 2023-08-28

### Fixed
Expand Down
44 changes: 28 additions & 16 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1552,23 +1552,23 @@ pub fn sendmmsg<'a, XS, AS, C, I, S>(
flags: MsgFlags
) -> crate::Result<MultiResults<'a, S>>
where
XS: IntoIterator<Item = &'a I>,
XS: IntoIterator<Item = I>,
AS: AsRef<[Option<S>]>,
I: AsRef<[IoSlice<'a>]> + 'a,
C: AsRef<[ControlMessage<'a>]> + 'a,
S: SockaddrLike + 'a
I: AsRef<[IoSlice<'a>]>,
C: AsRef<[ControlMessage<'a>]>,
S: SockaddrLike,
{

let mut count = 0;


for (i, ((slice, addr), mmsghdr)) in slices.into_iter().zip(addrs.as_ref()).zip(data.items.iter_mut() ).enumerate() {
let p = &mut mmsghdr.msg_hdr;
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
p.msg_iov = slice.as_ref().as_ptr().cast_mut().cast();
p.msg_iovlen = slice.as_ref().len() as _;

p.msg_namelen = addr.as_ref().map_or(0, S::len);
p.msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr) as _;
p.msg_name = addr.as_ref().map_or(ptr::null(), S::as_ptr).cast_mut().cast();

// Encode each cmsg. This must happen after initializing the header because
// CMSG_NEXT_HDR and friends read the msg_control and msg_controllen fields.
Expand All @@ -1583,9 +1583,16 @@ pub fn sendmmsg<'a, XS, AS, C, I, S>(
pmhdr = unsafe { CMSG_NXTHDR(p, pmhdr) };
}

count = i+1;
// Doing an unchecked addition is alright here, as the only way to obtain an instance of `MultiHeaders`
// is through the `preallocate` function, which takes an `usize` as an argument to define its size,
// which also provides an upper bound for the size of this zipped iterator. Thus, `i < usize::MAX` or in
// other words: `count` doesn't overflow
count = i + 1;
}

// SAFETY: all pointers are guaranteed to be valid for the scope of this function. `count` does represent the
// maximum number of messages that can be sent safely (i.e. `count` is the minimum of the sizes of `slices`,
// `data.items` and `addrs`)
let sent = Errno::result(unsafe {
libc::sendmmsg(
fd,
Expand Down Expand Up @@ -1711,21 +1718,28 @@ pub fn recvmmsg<'a, XS, S, I>(
mut timeout: Option<crate::sys::time::TimeSpec>,
) -> crate::Result<MultiResults<'a, S>>
where
XS: IntoIterator<Item = &'a I>,
I: AsRef<[IoSliceMut<'a>]> + 'a,
XS: IntoIterator<Item = I>,
I: AsMut<[IoSliceMut<'a>]>,
{
let mut count = 0;
for (i, (slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() {
for (i, (mut slice, mmsghdr)) in slices.into_iter().zip(data.items.iter_mut()).enumerate() {
let p = &mut mmsghdr.msg_hdr;
p.msg_iov = slice.as_ref().as_ptr() as *mut libc::iovec;
p.msg_iovlen = slice.as_ref().len() as _;
p.msg_iov = slice.as_mut().as_mut_ptr().cast();
p.msg_iovlen = slice.as_mut().len() as _;

// Doing an unchecked addition is alright here, as the only way to obtain an instance of `MultiHeaders`
// is through the `preallocate` function, which takes an `usize` as an argument to define its size,
// which also provides an upper bound for the size of this zipped iterator. Thus, `i < usize::MAX` or in
// other words: `count` doesn't overflow
count = i + 1;
}

let timeout_ptr = timeout
.as_mut()
.map_or_else(std::ptr::null_mut, |t| t as *mut _ as *mut libc::timespec);

// SAFETY: all pointers are guaranteed to be valid for the scope of this function. `count` does represent the
// maximum number of messages that can be received safely (i.e. `count` is the minimum of the sizes of `slices` and `data.items`)
let received = Errno::result(unsafe {
libc::recvmmsg(
fd,
Expand All @@ -1743,16 +1757,14 @@ where
})
}

/// Iterator over results of [`recvmmsg`]/[`sendmmsg`]
#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
))]
#[derive(Debug)]
/// Iterator over results of [`recvmmsg`]/[`sendmmsg`]
///
///
pub struct MultiResults<'a, S> {
// preallocated structures
rmm: &'a MultiHeaders<S>,
Expand Down Expand Up @@ -1903,7 +1915,7 @@ mod test {

let t = sys::time::TimeSpec::from_duration(std::time::Duration::from_secs(10));

let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter(), flags, Some(t))?;
let recv = super::recvmmsg(rsock.as_raw_fd(), &mut data, recv_iovs.iter_mut(), flags, Some(t))?;

for rmsg in recv {
#[cfg(not(any(qemu, target_arch = "aarch64")))]
Expand Down
19 changes: 12 additions & 7 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ mod recvfrom {
let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(
rsock.as_raw_fd(),
&mut data,
msgs.iter(),
msgs.iter_mut(),
MsgFlags::empty(),
None,
)
Expand Down Expand Up @@ -652,7 +652,7 @@ mod recvfrom {
let res: Vec<RecvMsg<SockaddrIn>> = recvmmsg(
rsock.as_raw_fd(),
&mut data,
msgs.iter(),
msgs.iter_mut(),
MsgFlags::MSG_DONTWAIT,
None,
)
Expand Down Expand Up @@ -2324,12 +2324,17 @@ fn test_recvmmsg_timestampns() {
// Receive the message
let mut buffer = vec![0u8; message.len()];
let cmsgspace = nix::cmsg_space!(TimeSpec);
let iov = vec![[IoSliceMut::new(&mut buffer)]];
let mut iov = vec![[IoSliceMut::new(&mut buffer)]];
let mut data = MultiHeaders::preallocate(1, Some(cmsgspace));
let r: Vec<RecvMsg<()>> =
recvmmsg(in_socket.as_raw_fd(), &mut data, iov.iter(), flags, None)
.unwrap()
.collect();
let r: Vec<RecvMsg<()>> = recvmmsg(
in_socket.as_raw_fd(),
&mut data,
iov.iter_mut(),
flags,
None,
)
.unwrap()
.collect();
let rtime = match r[0].cmsgs().next() {
Some(ControlMessageOwned::ScmTimestampns(rtime)) => rtime,
Some(_) => panic!("Unexpected control message"),
Expand Down