fix(nix-daemon): ensure Framed NARs are read exactly

This prevents framing confusion, which would otherwise lead to a
trivial confused deputy attack. See issue #120.

The NixFramedReader state machine has been refactored to simplify
its internal logic and accurately account for EOF conditions.

End-of-stream is fused, and unexpected EOF on the underlying reader
is returned as UnexpectedEof, though we don't fuse those ourselves.

We also ensure that the underlying reader does not swap the ReadBuf;
this would otherwise supply a primitive for converting uninitialised
mutable memory into `&mut [u8]` without initialisation, thus allowing
undefined behaviour to be triggered from safe code.

Change-Id: I05ddb7e3ca57b3363f56c0d9b43d5a641748ca36
Reviewed-on: https://cl.snix.dev/c/snix/+/30380
Reviewed-by: Brian Olsen <brian@maven-group.org>
Tested-by: besadii
Reviewed-by: Florian Klink <flokli@flokli.de>
This commit is contained in:
edef 2025-05-09 16:25:27 +00:00
parent 4ef7c50a2d
commit 9a8a9c6b67
2 changed files with 107 additions and 104 deletions

View file

@ -1,45 +1,35 @@
use std::{ use std::{
io::Result, num::NonZeroU64,
pin::Pin, pin::Pin,
task::{ready, Poll}, task::{self, ready, Poll},
}; };
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, ReadBuf}; use tokio::io::{self, AsyncRead, ReadBuf};
/// State machine for [`NixFramedReader`]. /// State machine for [`NixFramedReader`].
/// ///
/// As the reader progresses it linearly cycles through the states. /// We read length-prefixed chunks until we receive a zero-sized payload indicating EOF.
#[derive(Debug, PartialEq)] /// Other than the zero-sized terminating chunk, chunk boundaries are not considered meaningful.
enum NixFramedReaderState { /// Lengths are 64-bit little endian values on the wire.
/// The reader always starts in this state. #[derive(Debug, Eq, PartialEq)]
/// enum State {
/// Before the payload, the client first sends its size. Length { buf: [u8; 8], filled: u8 },
/// The size is a u64 which is 8 bytes long, while it's likely that we will receive Chunk { remaining: NonZeroU64 },
/// the whole u64 in one read, it's possible that it will arrive in smaller chunks. Eof,
/// So in this state we read up to 8 bytes and transition to
/// [`NixFramedReaderState::ReadingPayload`] when done if the read size is not zero,
/// otherwise we reset filled to 0, and read the next size value.
ReadingSize { buf: [u8; 8], filled: usize },
/// This is where we read the actual payload that is sent to us.
///
/// Once we've read the expected number of bytes, we go back to the
/// [`NixFramedReaderState::ReadingSize`] state.
ReadingPayload {
/// Represents the remaining number of bytes we expect to read based on the value
/// read in the previous state.
remaining: u64,
},
} }
pin_project! { pin_project! {
/// Implements Nix's Framed reader protocol for protocol versions >= 1.23. /// Implements Nix's [Framed] reader protocol for protocol versions >= 1.23.
/// ///
/// See serialization.md#framed and [`NixFramedReaderState`] for details. /// Unexpected EOF on the underlying reader is returned as [UnexpectedEof][`std::io::ErrorKind::UnexpectedEof`].
/// True EOF (end-of-stream) is fused.
///
/// [Framed]: https://snix.dev/docs/reference/nix-daemon-protocol/types/#framed
pub struct NixFramedReader<R> { pub struct NixFramedReader<R> {
#[pin] #[pin]
reader: R, reader: R,
state: NixFramedReaderState, state: State,
} }
} }
@ -47,100 +37,107 @@ impl<R> NixFramedReader<R> {
pub fn new(reader: R) -> Self { pub fn new(reader: R) -> Self {
Self { Self {
reader, reader,
state: NixFramedReaderState::ReadingSize { state: State::Length {
buf: [0; 8], buf: [0; 8],
filled: 0, filled: 0,
}, },
} }
} }
#[cfg(test)] /// Returns `true` if the Nix Framed reader has reached EOF.
fn is_eof(&self) -> bool { #[must_use]
self.state pub fn is_eof(&self) -> bool {
== NixFramedReaderState::ReadingSize { matches!(self.state, State::Eof)
buf: [0; 8],
filled: 8,
}
} }
} }
impl<R: AsyncRead> AsyncRead for NixFramedReader<R> { impl<R: AsyncRead> AsyncRead for NixFramedReader<R> {
fn poll_read( fn poll_read(
mut self: Pin<&mut Self>, mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut task::Context<'_>,
read_buf: &mut ReadBuf<'_>, buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> { ) -> Poll<io::Result<()>> {
let mut this = self.as_mut().project(); let mut this = self.as_mut().project();
// reading nothing always succeeds
if buf.remaining() == 0 {
return Ok(()).into();
}
loop {
let reader = this.reader.as_mut();
match this.state { match this.state {
NixFramedReaderState::ReadingSize { buf, filled } => { State::Eof => {
if *filled < buf.len() { return Ok(()).into();
let mut size_buf = ReadBuf::new(buf); }
size_buf.advance(*filled); State::Length { buf, filled } => {
let bytes_read = {
let mut b = ReadBuf::new(&mut buf[*filled as usize..]);
ready!(reader.poll_read(cx, &mut b))?;
b.filled().len() as u8
};
ready!(this.reader.poll_read(cx, &mut size_buf))?;
let bytes_read = size_buf.filled().len() - *filled;
if bytes_read == 0 { if bytes_read == 0 {
// oef return Err(io::ErrorKind::UnexpectedEof.into()).into();
return Poll::Ready(Ok(()));
} }
*filled += bytes_read; *filled += bytes_read;
// Schedule ourselves to run again.
return self.poll_read(cx, read_buf); if *filled == 8 {
*this.state = match NonZeroU64::new(u64::from_le_bytes(*buf)) {
None => State::Eof,
Some(remaining) => State::Chunk { remaining },
};
} }
let size = u64::from_le_bytes(*buf);
if size == 0 {
// eof
*filled = 0;
return Poll::Ready(Ok(()));
} }
*this.state = NixFramedReaderState::ReadingPayload { remaining: size }; State::Chunk { remaining } => {
self.poll_read(cx, read_buf) let bytes_read = ready!(with_limited(buf, remaining.get(), |buf| {
} reader.poll_read(cx, buf).map_ok(|()| buf.filled().len())
NixFramedReaderState::ReadingPayload { remaining } => { }))?;
// Make sure we never try to read more than usize which is 4 bytes on 32-bit platforms.
let safe_remaining = if *remaining <= usize::MAX as u64 { *this.state = match NonZeroU64::new(remaining.get() - bytes_read as u64) {
*remaining as usize None => State::Length {
buf: [0; 8],
filled: 0,
},
Some(remaining) => State::Chunk { remaining },
};
return if bytes_read == 0 {
Err(io::ErrorKind::UnexpectedEof.into())
} else { } else {
usize::MAX Ok(())
}; }
if safe_remaining > 0 { .into();
// The buffer is no larger than the amount of data that we expect. }
// Otherwise we will trim the buffer below and come back here. }
if read_buf.remaining() <= safe_remaining { }
let filled_before = read_buf.filled().len(); }
}
ready!(this.reader.as_mut().poll_read(cx, read_buf))?; /// Make a limited version of `buf`, consisting only of up to `n` bytes of the unfilled section, and call `f` with it.
let bytes_read = read_buf.filled().len() - filled_before; /// After `f` returns, we propagate the filled cursor advancement back to `buf`.
// TODO(edef): duplicate of src/wire/bytes/reader/mod.rs:with_limited
fn with_limited<R>(buf: &mut ReadBuf, n: u64, f: impl FnOnce(&mut ReadBuf) -> R) -> R {
let mut nbuf = buf.take(n.try_into().unwrap_or(usize::MAX));
let ptr = nbuf.initialized().as_ptr();
let ret = f(&mut nbuf);
*remaining -= bytes_read as u64; // SAFETY: `ReadBuf::take` only returns the *unfilled* section of `buf`,
if *remaining == 0 { // so anything filled is new, initialized data.
*this.state = NixFramedReaderState::ReadingSize { //
buf: [0; 8], // We verify that `nbuf` still points to the same buffer,
filled: 0, // so we're sure it hasn't been swapped out.
}; unsafe {
} // ensure our buffer hasn't been swapped out
return Poll::Ready(Ok(())); assert_eq!(nbuf.initialized().as_ptr(), ptr);
}
// Don't read more than remaining + pad bytes, it avoids unnecessary allocations and makes
// internal bookkeeping simpler.
let mut smaller_buf = read_buf.take(safe_remaining);
ready!(self.as_mut().poll_read(cx, &mut smaller_buf))?;
let bytes_read = smaller_buf.filled().len(); let n = nbuf.filled().len();
buf.assume_init(n);
buf.advance(n);
}
// SAFETY: we just read this number of bytes into read_buf's backing slice above. ret
unsafe { read_buf.assume_init(bytes_read) };
read_buf.advance(bytes_read);
return Poll::Ready(Ok(()));
}
*this.state = NixFramedReaderState::ReadingSize {
buf: [0; 8],
filled: 0,
};
self.poll_read(cx, read_buf)
}
}
}
} }
#[cfg(test)] #[cfg(test)]
@ -158,7 +155,6 @@ mod nix_framed_tests {
use crate::nix_daemon::framing::NixFramedReader; use crate::nix_daemon::framing::NixFramedReader;
#[tokio::test] #[tokio::test]
#[should_panic] // broken
async fn read_unexpected_eof_after_frame() { async fn read_unexpected_eof_after_frame() {
let mut mock = Builder::new() let mut mock = Builder::new()
// The client sends len // The client sends len
@ -179,7 +175,6 @@ mod nix_framed_tests {
} }
#[tokio::test] #[tokio::test]
#[should_panic] // broken
async fn read_unexpected_eof_in_frame() { async fn read_unexpected_eof_in_frame() {
let mut mock = Builder::new() let mut mock = Builder::new()
// The client sends len // The client sends len
@ -200,7 +195,6 @@ mod nix_framed_tests {
} }
#[tokio::test] #[tokio::test]
#[should_panic] // broken
async fn read_unexpected_eof_in_length() { async fn read_unexpected_eof_in_length() {
let mut mock = Builder::new() let mut mock = Builder::new()
// The client sends len // The client sends len
@ -219,7 +213,6 @@ mod nix_framed_tests {
} }
#[tokio::test] #[tokio::test]
#[should_panic] // broken
async fn read_hello_world_in_two_frames() { async fn read_hello_world_in_two_frames() {
let mut mock = Builder::new() let mut mock = Builder::new()
// The client sends len // The client sends len
@ -268,7 +261,6 @@ mod nix_framed_tests {
/// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input, /// Somewhat of a fuzz test, ensuring that we end up in identical states for the same input,
/// independent of how it is spread across read calls and poll cycles. /// independent of how it is spread across read calls and poll cycles.
#[test] #[test]
#[should_panic] // broken
fn split_verif() { fn split_verif() {
let mut cx = task::Context::from_waker(task::Waker::noop()); let mut cx = task::Context::from_waker(task::Waker::noop());
let mut input = make_framed(&[b"hello", b"world", b"!", b""]); let mut input = make_framed(&[b"hello", b"world", b"!", b""]);

View file

@ -189,16 +189,27 @@ where
writer.deref_mut(), writer.deref_mut(),
); );
self.io.add_to_store_nar(request, &mut reader).await self.io.add_to_store_nar(request, &mut reader).await
// TODO(edef): enforce framing synchronisation
}) })
.await? .await?
} }
23.. => { 23.. => {
// Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed // Starting at protocol version 1.23, the framed protocol is used, see serialization.md#framed
let mut framed = NixFramedReader::new(&mut self.reader); let mut framed = NixFramedReader::new(&mut self.reader);
Self::handle(&self.writer, async { Self::handle(&self.writer, async {
self.io.add_to_store_nar(request, &mut framed).await self.io.add_to_store_nar(request, &mut framed).await
}) })
.await? .await?;
// framing desynchronisation
// this MUST kill the connection
if !framed.is_eof() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"payload was not fully consumed",
));
}
} }
} }
} }