fix(nix-daemon): ensure Framed NARs are read exactly

This prevents framing confusion, which would otherwise lead to a
trivial confused deputy attack. See issue #120.

The NixFramedReader state machine has been refactored to simplify
its internal logic and accurately account for EOF conditions.

End-of-stream is fused, and unexpected EOF on the underlying reader
is returned as UnexpectedEof, though we don't fuse those ourselves.

We also ensure that the underlying reader does not swap the ReadBuf;
this would otherwise supply a primitive for converting uninitialised
mutable memory into `&mut [u8]` without initialisation, thus allowing
undefined behaviour to be triggered from safe code.

Change-Id: I05ddb7e3ca57b3363f56c0d9b43d5a641748ca36
Reviewed-on: https://cl.snix.dev/c/snix/+/30380
Reviewed-by: Brian Olsen <brian@maven-group.org>
Tested-by: besadii
Reviewed-by: Florian Klink <flokli@flokli.de>
This commit is contained in:
edef 2025-05-09 16:25:27 +00:00
parent 4ef7c50a2d
commit 9a8a9c6b67
2 changed files with 107 additions and 104 deletions

View file

@ -1,45 +1,35 @@
use std::{
io::Result,
num::NonZeroU64,
pin::Pin,
task::{ready, Poll},
task::{self, ready, Poll},
};
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{self, AsyncRead, ReadBuf};
/// State machine for [`NixFramedReader`].
///
/// As the reader progresses it linearly cycles through the states.
#[derive(Debug, PartialEq)]
enum NixFramedReaderState {
/// The reader always starts in this state.
///
/// Before the payload, the client first sends its size.
/// The size is a u64 which is 8 bytes long, while it's likely that we will receive
/// the whole u64 in one read, it's possible that it will arrive in smaller chunks.
/// So in this state we read up to 8 bytes and transition to
/// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero,
/// otherwise we reset filled to 0, and read the next size value.
ReadingSize { buf: [u8; 8], filled: usize },
/// This is where we read the actual payload that is sent to us.
///
/// Once we've read the expected number of bytes, we go back to the
/// [`NixFramedReaderState::ReadingSize`] state.
ReadingPayload {
/// Represents the remaining number of bytes we expect to read based on the value
/// read in the previous state.
remaining: u64,
},
/// We read length-prefixed chunks until we receive a zero-sized payload indicating EOF.
/// Other than the zero-sized terminating chunk, chunk boundaries are not considered meaningful.
/// Lengths are 64-bit little endian values on the wire.
#[derive(Debug, Eq, PartialEq)]
enum State {
Length { buf: [u8; 8], filled: u8 },
Chunk { remaining: NonZeroU64 },
Eof,
}
pin_project! {
/// Implements Nix's Framed reader protocol for protocol versions >= 1.23.
/// Implements Nix's [Framed] reader protocol for protocol versions >= 1.23.
///
/// See serialization.md#framed and [`NixFramedReaderState`] for details.
/// Unexpected EOF on the underlying reader is returned as [UnexpectedEof][`std::io::ErrorKind::UnexpectedEof`].
/// True EOF (end-of-stream) is fused.
///
/// [Framed]: https://snix.dev/docs/reference/nix-daemon-protocol/types/#framed
pub struct NixFramedReader<R> {
#[pin]
reader: R,
state: NixFramedReaderState,
state: State,
}
}
@ -47,100 +37,107 @@ impl<R> NixFramedReader<R> {
pub fn new(reader: R) -> Self {
Self {
reader,
state: NixFramedReaderState::ReadingSize {
state: State::Length {
buf: [0; 8],
filled: 0,
},
}
}
#[cfg(test)]
fn is_eof(&self) -> bool {
self.state
== NixFramedReaderState::ReadingSize {
buf: [0; 8],
filled: 8,
}
/// Returns `true` if the Nix Framed reader has reached EOF.
#[must_use]
pub fn is_eof(&self) -> bool {
matches!(self.state, State::Eof)
}
}
impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
read_buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut this = self.as_mut().project();
// reading nothing always succeeds
if buf.remaining() == 0 {
return Ok(()).into();
}
loop {
let reader = this.reader.as_mut();
match this.state {
NixFramedReaderState::ReadingSize { buf, filled } => {
if *filled < buf.len() {
let mut size_buf = ReadBuf::new(buf);
size_buf.advance(*filled);
State::Eof => {
return Ok(()).into();
}
State::Length { buf, filled } => {
let bytes_read = {
let mut b = ReadBuf::new(&mut buf[*filled as usize..]);
ready!(reader.poll_read(cx, &mut b))?;
b.filled().len() as u8
};
ready!(this.reader.poll_read(cx, &mut size_buf))?;
let bytes_read = size_buf.filled().len() - *filled;
if bytes_read == 0 {
// oef
return Poll::Ready(Ok(()));
return Err(io::ErrorKind::UnexpectedEof.into()).into();
}
*filled += bytes_read;
// Schedule ourselves to run again.
return self.poll_read(cx, read_buf);
if *filled == 8 {
*this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
None => State::Eof,
Some(remaining) => State::Chunk { remaining },
};
}
let size = u64::from_le_bytes(*buf);
if size == 0 {
// eof
*filled = 0;
return Poll::Ready(Ok(()));
}
*this.state = NixFramedReaderState::ReadingPayload { remaining: size };
self.poll_read(cx, read_buf)
}
NixFramedReaderState::ReadingPayload { remaining } => {
// Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
let safe_remaining = if *remaining <= usize::MAX as u64 {
*remaining as usize
State::Chunk { remaining } => {
let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| {
reader.poll_read(cx, buf).map_ok(|()| buf.filled().len())
}))?;
*this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) {
None => State::Length {
buf: [0; 8],
filled: 0,
},
Some(remaining) => State::Chunk { remaining },
};
return if bytes_read == 0 {
Err(io::ErrorKind::UnexpectedEof.into())
} else {
usize::MAX
};
if safe_remaining > 0 {
// The buffer is no larger than the amount of data that we expect.
// Otherwise we will trim the buffer below and come back here.
if read_buf.remaining() <= safe_remaining {
let filled_before = read_buf.filled().len();
Ok(())
}
.into();
}
}
}
}
}
ready!(this.reader.as_mut().poll_read(cx, read_buf))?;
let bytes_read = read_buf.filled().len() - filled_before;
/// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it.
/// After `f` returns, we propagate the filled cursor advancement back to `buf`.
// TODO(edef): duplicate of src/wire/bytes/reader/mod.rs:with_limited
fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
let ptr = nbuf.initialized().as_ptr();
let ret = f(&mut nbuf);
*remaining -= bytes_read as u64;
if *remaining == 0 {
*this.state = NixFramedReaderState::ReadingSize {
buf: [0; 8],
filled: 0,
};
}
return Poll::Ready(Ok(()));
}
// Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
// internal bookkeeping simpler.
let mut smaller_buf = read_buf.take(safe_remaining);
ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
// SAFETY: `ReadBuf::take` only returns the *unfilled* section of `buf`,
// so anything filled is new, initialized data.
//
// We verify that `nbuf` still points to the same buffer,
// so we're sure it hasn't been swapped out.
unsafe {
// ensure our buffer hasn't been swapped out
assert_eq!(nbuf.initialized().as_ptr(), ptr);
let bytes_read = smaller_buf.filled().len();
let n = nbuf.filled().len();
buf.assume_init(n);
buf.advance(n);
}
// SAFETY: we just read this number of bytes into read_buf's backing slice above.
unsafe { read_buf.assume_init(bytes_read) };
read_buf.advance(bytes_read);
return Poll::Ready(Ok(()));
}
*this.state = NixFramedReaderState::ReadingSize {
buf: [0; 8],
filled: 0,
};
self.poll_read(cx, read_buf)
}
}
}
ret
}
#[cfg(test)]
@ -158,7 +155,6 @@ mod nix_framed_tests {
use crate::nix_daemon::framing::NixFramedReader;
#[tokio::test]
#[should_panic] // broken
async fn read_unexpected_eof_after_frame() {
let mut mock = Builder::new()
// The client sends len
@ -179,7 +175,6 @@ mod nix_framed_tests {
}
#[tokio::test]
#[should_panic] // broken
async fn read_unexpected_eof_in_frame() {
let mut mock = Builder::new()
// The client sends len
@ -200,7 +195,6 @@ mod nix_framed_tests {
}
#[tokio::test]
#[should_panic] // broken
async fn read_unexpected_eof_in_length() {
let mut mock = Builder::new()
// The client sends len
@ -219,7 +213,6 @@ mod nix_framed_tests {
}
#[tokio::test]
#[should_panic] // broken
async fn read_hello_world_in_two_frames() {
let mut mock = Builder::new()
// The client sends len
@ -268,7 +261,6 @@ mod nix_framed_tests {
/// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input,
/// independent of how it is spread across read calls and poll cycles.
#[test]
#[should_panic] // broken
fn split_verif() {
let mut cx = task::Context::from_waker(task::Waker::noop());
let mut input = make_framed(&[b"hello", b"world", b"!", b""]);

View file

@ -189,16 +189,27 @@ where
writer.deref_mut(),
);
self.io.add_to_store_nar(request, &mut reader).await
// TODO(edef): enforce framing synchronisation
})
.await?
}
23.. => {
// Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed
let mut framed = NixFramedReader::new(&mut self.reader);
Self::handle(&self.writer, async {
self.io.add_to_store_nar(request, &mut framed).await
})
.await?
.await?;
// framing desynchronisation
// this MUST kill the connection
if !framed.is_eof() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"payload was not fully consumed",
));
}
}
}
}